!16035 Add Server part 3

From: @zpac
Reviewed-by: @cristoval
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-05-08 09:20:47 +08:00 committed by Gitee
commit dcec57955c
34 changed files with 2379 additions and 27 deletions

View File

@ -47,6 +47,7 @@
#include "ps/parameter_server.h"
#include "ps/scheduler.h"
#include "ps/worker.h"
#include "ps/server/server.h"
#endif
namespace mindspore {
@ -619,6 +620,47 @@ bool StartPSServerAction(const ResourcePtr &res) {
return true;
}
bool StartServerAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
size_t worker_num = ps::PSContext::instance()->initial_worker_num();
size_t server_num = ps::PSContext::instance()->initial_server_num();
uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();
// Update model threshold is a certain ratio of start_fl_job threshold.
// update_model_threshold_ = start_fl_job_threshold_ * percent_for_update_model_.
size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
float percent_for_update_model = 1;
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));
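// Each brace-initialized entry below is a ps::server::RoundConfig in the field order
// {name, check_timeout, time_window(ms), check_count, threshold_count}; omitted fields take the struct's defaults.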
std::vector<ps::server::RoundConfig> rounds_config = {
{"startFLJob", false, 3000, false, start_fl_job_threshold},
{"updateModel", false, 3000, false, update_model_threshold},
{"getModel", false, 3000},
{"asyncUpdateModel"},
{"asyncGetModel"},
{"push", false, 3000, true, worker_num},
{"pull", false, 3000, true, worker_num},
{"getWeightsByKey", false, 3000, true, 1},
{"overwriteWeightsByKey", false, 3000, true, server_num},
};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
executor_threshold = update_model_threshold;
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph,
executor_threshold);
} else if (server_mode_ == ps::kServerModePS) {
executor_threshold = worker_num;
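// In parameter server mode only the TCP communicator is used: use_tcp is true, use_http is false and the
// HTTP port argument is unused, so 0 is passed.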
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold);
} else {
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
return false;
}
ps::server::Server::GetInstance().Run();
return true;
}
bool StartPSSchedulerAction(const ResourcePtr &res) {
ps::Scheduler::GetInstance().Run();
return true;
@ -797,6 +839,14 @@ std::vector<ActionItem> VmPipeline() {
}
#if (ENABLE_CPU && !_WIN32)
std::vector<ActionItem> ServerPipeline() {
auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
actions.emplace_back(std::make_pair("server", StartServerAction));
return actions;
}
std::vector<ActionItem> PServerPipeline() {
auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));

View File

@ -43,10 +43,14 @@ bool ExecuteAction(const ResourcePtr &res);
bool StartPSWorkerAction(const ResourcePtr &res);
bool StartPSServerAction(const ResourcePtr &res);
bool StartPSSchedulerAction(const ResourcePtr &res);
// This action is only for federated learning. In a later version, parameter server mode and federated learning
// will use the same action.
bool StartServerAction(const ResourcePtr &res);
std::vector<ActionItem> GePipeline();
std::vector<ActionItem> VmPipeline();
std::vector<ActionItem> PServerPipeline();
std::vector<ActionItem> ServerPipeline();
std::vector<ActionItem> PSchedulerPipeline();
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
const abstract::AbstractBasePtrList &args_spec, bool clear = false);

View File

@ -326,7 +326,24 @@ PYBIND11_MODULE(_c_expression, m) {
.def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
.def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
.def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
.def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
.def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
.def("server_mode", &PSContext::server_mode, "Get server mode.")
.def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.")
.def("ms_role", &PSContext::ms_role, "Get role for this process.")
.def("set_worker_num", &PSContext::set_worker_num, "Set worker number.")
.def("set_server_num", &PSContext::set_server_num, "Set server number.")
.def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.")
.def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
.def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
.def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
.def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold, "Set threshold count for start_fl_job.")
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
.def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
.def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
"Set federated learning client using secure aggregation.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())

View File

@ -55,6 +55,7 @@
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/server/server.h"
#endif
#if (ENABLE_GE || ENABLE_D)
@ -529,6 +530,11 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
std::string backend = MsContext::GetInstance()->backend_policy();
#if (ENABLE_CPU && !_WIN32)
const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
ps::PSContext::instance()->is_server()) {
return ServerPipeline();
}
if (ps::PSContext::instance()->is_server()) {
resource->results()[kBackend] = compile::CreateBackend();
return PServerPipeline();

View File

@ -50,10 +50,13 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/fed_avg_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/update_model_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/get_model_kernel.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
@ -67,6 +70,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc")
list(REMOVE_ITEM _PS_SRC_FILES "server/server.cc")
endif()
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")

View File

@ -61,7 +61,12 @@ void PSContext::SetPSEnable(bool enabled) {
}
}
bool PSContext::is_ps_mode() const { return ps_enabled_; }
bool PSContext::is_ps_mode() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
return true;
}
return ps_enabled_;
}
void PSContext::Reset() {
ps_enabled_ = false;
@ -77,6 +82,9 @@ void PSContext::Reset() {
}
std::string PSContext::ms_role() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
return role_;
}
if (is_worker_) {
return kEnvRoleOfWorker;
} else if (is_pserver_) {
@ -88,11 +96,26 @@ std::string PSContext::ms_role() const {
}
}
bool PSContext::is_worker() const { return is_worker_; }
bool PSContext::is_worker() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
return role_ == kRoleOfWorker;
}
return is_worker_;
}
bool PSContext::is_server() const { return is_pserver_; }
bool PSContext::is_server() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
return role_ == kEnvRoleOfServer;
}
return is_pserver_;
}
bool PSContext::is_scheduler() const { return is_sched_; }
bool PSContext::is_scheduler() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
return role_ == kEnvRoleOfScheduler;
}
return is_sched_;
}
uint32_t PSContext::initial_worker_num() { return worker_num_; }
@ -150,6 +173,94 @@ void PSContext::set_rank_id(int rank_id) const {
#endif
}
void PSContext::set_server_mode(const std::string &server_mode) {
if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
<< " or " << kServerModeHybrid;
return;
}
server_mode_ = server_mode;
}
const std::string &PSContext::server_mode() const { return server_mode_; }
void PSContext::set_ms_role(const std::string &role) {
if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by ps context.";
return;
}
if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
return;
}
role_ = role;
}
void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
uint32_t PSContext::worker_num() const { return worker_num_; }
void PSContext::set_server_num(uint32_t server_num) {
if (server_num == 0) {
MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
return;
}
server_num_ = server_num;
}
uint32_t PSContext::server_num() const { return server_num_; }
void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
std::string PSContext::scheduler_ip() const { return scheduler_host_; }
void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
void PSContext::GenerateResetterRound() {
uint32_t binary_server_context = 0;
bool is_parameter_server_mode = false;
bool is_federated_learning_mode = false;
bool is_mixed_training_mode = false;
if (server_mode_ == kServerModePS) {
is_parameter_server_mode = true;
} else if (server_mode_ == kServerModeFL) {
is_federated_learning_mode = true;
} else if (server_mode_ == kServerModeHybrid) {
is_mixed_training_mode = true;
} else {
MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
<< " or " << kServerModeHybrid;
return;
}
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset;
} else {
resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
}
MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
return;
}
ResetterRound PSContext::resetter_round() const { return resetter_round_; }
void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
bool PSContext::fl_client_enable() { return fl_client_enable_; }
void PSContext::set_start_fl_job_threshold(size_t start_fl_job_threshold) {
start_fl_job_threshold_ = start_fl_job_threshold;
}
size_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
const std::string &PSContext::fl_name() const { return fl_name_; }
@ -165,5 +276,15 @@ uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) {
worker_overwrite_weights_ = worker_overwrite_weights;
}
uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; }
void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }
bool PSContext::secure_aggregation() const { return secure_aggregation_; }
} // namespace ps
} // namespace mindspore

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
#define MINDSPORE_CCSRC_PS_CONTEXT_H_
#include <map>
#include <string>
#include <memory>
#include "ps/constants.h"
@ -24,12 +25,32 @@
namespace mindspore {
namespace ps {
constexpr char kServerModePS[] = "PARAMETER_SERVER";
constexpr char kServerModeFL[] = "FEDERATED_LEARNING";
constexpr char kServerModeHybrid[] = "HYBRID_TRAINING";
constexpr char kEnvRole[] = "MS_ROLE";
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfServer[] = "MS_SERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
// Use binary data to represent the federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
// 0: Server is in parameter server mode.
// 1: Server is in federated learning mode.
// 2: Server is in mixed training mode.
// 3: Server enables secure aggregation.
// 4: Server needs worker to overwrite weights.
// For example, 01010 stands for a server in federated learning mode with secure aggregation enabled.
enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSecrets, kWorkerOverwriteWeights };
const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {
{0b00010, ResetterRound::kUpdateModel},
{0b01010, ResetterRound::kReconstructSecrets},
{0b11100, ResetterRound::kWorkerOverwriteWeights},
{0b10100, ResetterRound::kWorkerOverwriteWeights},
{0b00100, ResetterRound::kUpdateModel}};
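// For example, a context of 0b01010 (federated learning mode with secure aggregation enabled, per the bit layout
// above) maps to ResetterRound::kReconstructSecrets.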
class PSContext {
public:
~PSContext() = default;
@ -60,19 +81,64 @@ class PSContext {
void set_cache_enable(bool cache_enable) const;
void set_rank_id(int rank_id) const;
// Setter and getter for federated learning.
// In the new server framework, the process role, worker number, server number, scheduler IP and scheduler port
// should be set through ps_context.
void set_server_mode(const std::string &server_mode);
const std::string &server_mode() const;
void set_ms_role(const std::string &role);
void set_worker_num(uint32_t worker_num);
uint32_t worker_num() const;
void set_server_num(uint32_t server_num);
uint32_t server_num() const;
void set_scheduler_ip(const std::string &sched_ip);
std::string scheduler_ip() const;
void set_scheduler_port(uint16_t sched_port);
uint16_t scheduler_port() const;
// Methods for federated learning.
// Generate which round should reset the iteration.
void GenerateResetterRound();
ResetterRound resetter_round() const;
void set_fl_server_port(uint16_t fl_server_port);
uint16_t fl_server_port() const;
// Set true if this process is a federated learning worker in the cross-silo scenario.
void set_fl_client_enable(bool enabled);
bool fl_client_enable();
void set_start_fl_job_threshold(size_t start_fl_job_threshold);
size_t start_fl_job_threshold() const;
void set_fl_name(const std::string &fl_name);
const std::string &fl_name() const;
// Set the iteration number of the federated learning.
void set_fl_iteration_num(uint64_t fl_iteration_num);
uint64_t fl_iteration_num() const;
// Set the training epoch number of the client.
void set_client_epoch_num(uint64_t client_epoch_num);
uint64_t client_epoch_num() const;
// Set the data batch size of the client.
void set_client_batch_size(uint64_t client_batch_size);
uint64_t client_batch_size() const;
// Set true if the worker will overwrite weights on the server. Used in hybrid training.
void set_worker_overwrite_weights(uint64_t worker_overwrite_weights);
uint64_t worker_overwrite_weights() const;
// Set true if using secure aggregation for federated learning.
void set_secure_aggregation(bool secure_aggregation);
bool secure_aggregation() const;
private:
PSContext()
: ps_enabled_(false),
@ -94,11 +160,22 @@ class PSContext {
std::string scheduler_host_;
uint16_t scheduler_port_;
std::string role_;
// Members for federated learning.
std::string server_mode_;
ResetterRound resetter_round_;
uint16_t fl_server_port_;
bool fl_client_enable_;
std::string fl_name_;
size_t start_fl_job_threshold_;
uint64_t fl_iteration_num_;
uint64_t client_epoch_num_;
uint64_t client_batch_size_;
bool worker_overwrite_weights_;
// Federated learning security.
bool secure_aggregation_;
};
} // namespace ps
} // namespace mindspore

View File

@ -20,6 +20,9 @@ namespace mindspore {
namespace ps {
void Scheduler::Run() {
MS_LOG(INFO) << "Start scheduler.";
core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
scheduler_node_.Start();
scheduler_node_.Finish();
scheduler_node_.Stop();

View File

@ -44,6 +44,14 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
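// Configuration for one server round: check_timeout/time_window(ms) control the round's timer, while
// check_count/threshold_count control count-based completion of the round.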
struct RoundConfig {
std::string name;
bool check_timeout = false;
size_t time_window = 3000;
bool check_count = false;
size_t threshold_count = 0;
};
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
@ -73,6 +81,7 @@ using ReuseKernelNodeInfo = std::map<std::string, size_t>;
using UploadData = std::map<std::string, Address>;
constexpr auto kWeight = "weight";
constexpr auto kNewWeight = "new_weight";
constexpr auto kAccumulation = "accum";
constexpr auto kLearningRate = "lr";
constexpr auto kGradient = "grad";
@ -87,6 +96,8 @@ constexpr auto kAdamBeta1 = "beta1";
constexpr auto kAdamBeta2 = "beta2";
constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
// launched.
@ -137,6 +148,7 @@ constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200;
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
constexpr auto kCtxFuncGraph = "FuncGraph";
@ -145,6 +157,8 @@ constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
// This macro is the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \

View File

@ -112,19 +112,19 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
std::unique_lock<std::mutex> lock(mutex_[name]);
return global_current_count_[name].size() == global_threshold_count_[name];
} else {
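// Servers other than the leader query the counting (leader) server for the global count over TCP.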
CountReachThresholdRequest count_reach_threashold_req;
count_reach_threashold_req.set_name(name);
CountReachThresholdRequest count_reach_threshold_req;
count_reach_threshold_req.set_name(name);
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_,
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
return false;
}
CountReachThresholdResponse count_reach_threashold_rsp;
count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
return count_reach_threashold_rsp.is_enough();
CountReachThresholdResponse count_reach_threshold_rsp;
count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
return count_reach_threshold_rsp.is_enough();
}
}
@ -200,9 +200,9 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
return;
}
CountReachThresholdRequest count_reach_threashold_req;
count_reach_threashold_req.ParseFromArray(message->data(), message->len());
const std::string &name = count_reach_threashold_req.name();
CountReachThresholdRequest count_reach_threshold_req;
count_reach_threshold_req.ParseFromArray(message->data(), message->len());
const std::string &name = count_reach_threshold_req.name();
std::unique_lock<std::mutex> lock(mutex_[name]);
if (global_threshold_count_.count(name) == 0) {
@ -210,10 +210,10 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
return;
}
CountReachThresholdResponse count_reach_threashold_rsp;
count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(),
count_reach_threashold_rsp.SerializeAsString().size(), message);
CountReachThresholdResponse count_reach_threshold_rsp;
count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
count_reach_threshold_rsp.SerializeAsString().size(), message);
return;
}

View File

@ -193,7 +193,29 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<co
bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) {
std::unique_lock<std::mutex> lock(mutex_[name]);
metadata_[name] = meta;
if (meta.has_device_meta()) {
auto &fl_id_to_meta_map = *metadata_[name].mutable_device_metas()->mutable_fl_id_to_meta();
auto &fl_id = meta.device_meta().fl_id();
auto &device_meta = meta.device_meta();
fl_id_to_meta_map[fl_id] = device_meta;
} else if (meta.has_fl_id()) {
auto client_list = metadata_[name].mutable_client_list();
auto &fl_id = meta.fl_id().fl_id();
// Check whether the new item already exists.
bool add_flag = true;
for (int i = 0; i < client_list->fl_id_size(); i++) {
if (fl_id == client_list->fl_id(i)) {
add_flag = false;
break;
}
}
if (add_flag) {
client_list->add_fl_id(fl_id);
}
} else if (meta.has_update_model_threshold()) {
auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
*update_model_threshold = meta.update_model_threshold();
}
return true;
}
} // namespace server

View File

@ -23,7 +23,7 @@
namespace mindspore {
namespace ps {
namespace server {
void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
MS_EXCEPTION_IF_NULL(func_graph);
if (aggregation_count == 0) {
MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";

View File

@ -43,7 +43,7 @@ class Executor {
// be used for aggregators.
// As noted in the header file parameter_aggregator.h, we create aggregators from trainable parameters, which are
// the optimizer cnode's inputs. So we need to initialize the server executor using func_graph.
void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);
void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);
// Called in parameter server training mode to do Push operation.
// For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/fed_avg_kernel.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL_TWO(FedAvg,
ParamsInfo()
.AddInputNameType(kWeight, kNumberTypeFloat32)
.AddInputNameType(kDataSize, kNumberTypeUInt64)
.AddInputNameType(kNewWeight, kNumberTypeFloat32)
.AddInputNameType(kNewDataSize, kNumberTypeUInt64),
FedAvgKernel, float, size_t)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,179 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <functional>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "ps/server/common.h"
#include "ps/server/collective_ops_impl.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/local_meta_store.h"
#include "ps/server/kernel/aggregation_kernel.h"
#include "ps/server/kernel/aggregation_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
// The implementation of the federated average. We do a weighted average of the weights. The uploaded weights from
// FL-clients are already multiplied by their data sizes, so only summation and division are done in this kernel.
// Pay attention that this kernel is the distributed version of federated average, which means each server node in
// the cluster is involved in the aggregation process. So the DistributedCountService and CollectiveOpsImpl are called.
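// As an illustration: if two clients hold n1 and n2 samples, each uploads n_i * w_i, so the aggregated weight is
// (n1 * w1 + n2 * w2) / (n1 + n2), where the denominator is the allreduced total data size.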
template <typename T, typename S>
class FedAvgKernel : public AggregationKernel {
public:
FedAvgKernel() : participated_(false) {}
~FedAvgKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
kNameToIdxMap.at(cnode_name).at("inputs").count("weight") == 0) {
MS_LOG(EXCEPTION) << "Can't find index info of weight for kernel " << cnode_name;
return;
}
cnode_weight_idx_ = kNameToIdxMap.at(cnode_name).at("inputs").at("weight");
std::vector<size_t> weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_weight_idx_);
size_t weight_size =
std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies<size_t>());
size_t new_weight_size = weight_size;
input_size_list_.push_back(weight_size);
input_size_list_.push_back(sizeof(size_t));
input_size_list_.push_back(new_weight_size);
input_size_list_.push_back(sizeof(size_t));
auto weight_node =
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope();
MS_LOG(INFO) << "Register counter for " << name_;
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
std::unique_lock<std::mutex> lock(weight_mutex_);
if (!participated_) {
ClearWeightAndDataSize();
}
};
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size;
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
if (!CollectiveOpsImpl::GetInstance().AllReduce<T>(weight_addr, weight_addr, weight_size / sizeof(T))) {
MS_LOG(ERROR) << "Federated average allreduce failed.";
return;
}
if (!CollectiveOpsImpl::GetInstance().AllReduce<S>(data_size_addr, data_size_addr, 1)) {
MS_LOG(ERROR) << "Federated average allreduce failed.";
return;
}
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, data_size_addr[0]);
for (size_t i = 0; i < weight_size / sizeof(T); i++) {
weight_addr[i] /= data_size_addr[0];
}
done_ = true;
DistributedCountService::GetInstance().ResetCounter(name_);
return;
};
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
return;
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
std::unique_lock<std::mutex> lock(weight_mutex_);
// The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication
// again.
T *weight_addr = reinterpret_cast<T *>(inputs[0]->addr);
S *data_size_addr = reinterpret_cast<S *>(inputs[1]->addr);
T *new_weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
S *new_data_size_addr = reinterpret_cast<S *>(inputs[3]->addr);
if (accum_count_ == 0) {
ClearWeightAndDataSize();
}
MS_LOG(DEBUG) << "Iteration: " << LocalMetaStore::GetInstance().curr_iter_num() << " launching FedAvgKernel for "
<< name_ << " new data size is " << new_data_size_addr[0] << ", current total data size is "
<< data_size_addr[0];
for (size_t i = 0; i < inputs[2]->size / sizeof(T); i++) {
weight_addr[i] += new_weight_addr[i];
}
data_size_addr[0] += new_data_size_addr[0];
lock.unlock();
accum_count_++;
participated_ = true;
DistributedCountService::GetInstance().Count(
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
return true;
}
void Reset() {
accum_count_ = 0;
done_ = false;
participated_ = false;
DistributedCountService::GetInstance().ResetCounter(name_);
return;
}
bool IsAggregationDone() { return done_; }
private:
void GenerateReuseKernelNodeInfo() override {
// Only the trainable parameter is reused for federated average.
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, cnode_weight_idx_));
return;
}
// In some cases, the Launch method is not called and the weights involved in AllReduce should be set to 0.
void ClearWeightAndDataSize() {
int ret = memset_s(weight_addr_->addr, weight_addr_->size, 0x00, weight_addr_->size);
if (ret != 0) {
MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
return;
}
ret = memset_s(data_size_addr_->addr, data_size_addr_->size, 0x00, data_size_addr_->size);
if (ret != 0) {
MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
return;
}
return;
}
// The trainable parameter index of the kernel node which is parsed from the frontend func_graph.
size_t cnode_weight_idx_;
// The address pointer of the inputs.
AddressPtr weight_addr_;
AddressPtr data_size_addr_;
AddressPtr new_weight_addr_;
AddressPtr new_data_size_addr_;
// Whether the kernel's Launch method is called.
bool participated_;
// The kernel could be called concurrently, so we need a lock to ensure thread safety.
std::mutex weight_mutex_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/get_model_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/model_store.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void GetModelKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
}
bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching GetModelKernel kernel.";
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
}
const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
GetModel(get_model_req, fbb);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
bool GetModelKernel::Reset() {
MS_LOG(INFO) << "Get model kernel reset!";
StopTimer();
return true;
}
void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb) {
std::map<std::string, AddressPtr> feature_maps;
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first;
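// The model of an iteration is stored only after that iteration finishes, so the model is "not ready" if the
// client asks for the current, still-running iteration or for one iteration ahead of it.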
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(WARNING) << reason;
return;
}
if (iter_to_model.count(get_model_iter) == 0) {
std::string reason = "The iteration of GetModel request" + std::to_string(get_model_iter) +
" is invalid. Current iteration is " + std::to_string(current_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return;
}
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED,
"Get model for iteration " + std::to_string(get_model_iter) + " success.", current_iter,
feature_maps, std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
return;
}
void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const size_t iter,
const std::map<std::string, AddressPtr> &feature_maps,
const std::string &timestamp) {
auto fbs_reason = fbb->CreateString(reason);
auto fbs_timestamp = fbb->CreateString(timestamp);
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (const auto &feature_map : feature_maps) {
auto fbs_weight_fullname = fbb->CreateString(feature_map.first);
auto fbs_weight_data =
fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float));
auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data);
fbs_feature_maps.push_back(fbs_feature_map);
}
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
schema::ResponseGetModelBuilder rsp_get_model_builder(*(fbb.get()));
rsp_get_model_builder.add_retcode(retcode);
rsp_get_model_builder.add_reason(fbs_reason);
rsp_get_model_builder.add_iteration(static_cast<int>(iter));
rsp_get_model_builder.add_feature_map(fbs_feature_maps_vector);
rsp_get_model_builder.add_timestamp(fbs_timestamp);
auto rsp_get_model = rsp_get_model_builder.Finish();
fbb->Finish(rsp_get_model);
return;
}
REG_ROUND_KERNEL(getModel, GetModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class GetModelKernel : public RoundKernel {
public:
GetModelKernel() = default;
~GetModelKernel() override = default;
void InitKernel(size_t) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Reset() override;
private:
void GetModel(const schema::RequestGetModel *get_model_req, const std::shared_ptr<FBBuilder> &fbb);
void BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const size_t iter,
const std::map<std::string, AddressPtr> &feature_maps, const std::string &timestamp);
// The executor is for getting the model for the getModel request.
Executor *executor_;
// The time window of one iteration.
size_t iteration_time_window_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_

View File

@ -49,7 +49,7 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return false;
}
void *req_data = inputs[0]->addr;
const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>();
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;

View File

@ -0,0 +1,203 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/kernel/round/update_model_kernel.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void UpdateModelKernel::InitKernel(size_t threshold_count) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
PBMetadata client_list;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxUpdateModelClientList, client_list);
LocalMetaStore::GetInstance().put_value(kCtxUpdateModelThld, threshold_count);
}
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "inputs or outputs size is invalid.";
return false;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr.";
return false;
}
MS_LOG(INFO) << "Launching UpdateModelKernel kernel.";
if (!ReachThresholdForUpdateModel(fbb)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
const schema::RequestUpdateModel *update_model_req = flatbuffers::GetRoot<schema::RequestUpdateModel>(req_data);
if (!UpdateModel(update_model_req, fbb)) {
MS_LOG(ERROR) << "Updating model failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
if (!CountForUpdateModel(fbb, update_model_req)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
bool UpdateModelKernel::Reset() {
MS_LOG(INFO) << "Update model kernel reset!";
StopTimer();
DistributedCountService::GetInstance().ResetCounter(name_);
executor_->ResetAggregationStatus();
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxUpdateModelClientList);
size_t &total_data_size = LocalMetaStore::GetInstance().mutable_value<size_t>(kCtxFedAvgTotalDataSize);
total_data_size = 0;
return true;
}
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) {
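// Block until every weight finishes aggregation before this iteration is considered done.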
while (!executor_->IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size;
FinishIterCb();
}
}
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for updateModel is enough.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
size_t iteration = static_cast<size_t>(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return false;
}
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
std::string update_model_fl_id = update_model_req->fl_id()->str();
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return false;
}
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
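// The data size reported by this client weights its update: it is attached as the kNewDataSize input so the
// FedAvg aggregation kernel can accumulate it into the total data size.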
auto feature_map = ParseFeatureMap(update_model_req);
for (auto weight : feature_map) {
weight.second[kNewDataSize].addr = &data_size;
weight.second[kNewDataSize].size = sizeof(size_t);
executor_->HandleModelUpdate(weight.first, weight.second);
}
FLId fl_id;
fl_id.set_fl_id(update_model_fl_id);
PBMetadata comm_value;
*comm_value.mutable_fl_id() = fl_id;
DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxUpdateModelClientList, comm_value);
BuildUpdateModelRsp(fbb, schema::ResponseCode_SucNotReady, "success not ready",
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
return true;
}
std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
const schema::RequestUpdateModel *update_model_req) {
RETURN_IF_NULL(update_model_req, {});
std::map<std::string, UploadData> feature_map;
auto fbs_feature_map = update_model_req->feature_map();
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());
size_t weight_size = fbs_feature_map->Get(i)->data()->size() * sizeof(float);
UploadData upload_data;
upload_data[kNewWeight].addr = weight_data;
upload_data[kNewWeight].size = weight_size;
feature_map[weight_full_name] = upload_data;
}
return feature_map;
}
bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "UpdateModel counting failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_));
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time) {
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ResponseUpdateModelBuilder rsp_update_model_builder(*(fbb.get()));
rsp_update_model_builder.add_retcode(retcode);
rsp_update_model_builder.add_reason(fbs_reason);
rsp_update_model_builder.add_next_req_time(fbs_next_req_time);
auto rsp_update_model = rsp_update_model_builder.Finish();
fbb->Finish(rsp_update_model);
return;
}
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class UpdateModelKernel : public RoundKernel {
public:
UpdateModelKernel() = default;
~UpdateModelKernel() override = default;
void InitKernel(size_t threshold_count) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Reset() override;
// In some cases, the last updateModel message means this server iteration is finished.
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
private:
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
bool UpdateModel(const schema::RequestUpdateModel *update_model_req, const std::shared_ptr<FBBuilder> &fbb);
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
bool CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestUpdateModel *update_model_req);
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time);
// The executor is for updating the model for the updateModel request.
Executor *executor_;
// The time window of one iteration.
size_t iteration_time_window_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

View File

@ -23,7 +23,7 @@
namespace mindspore {
namespace ps {
namespace server {
void ModelStore::Init(uint32_t max_count) {
void ModelStore::Initialize(uint32_t max_count) {
if (!Executor::GetInstance().initialized()) {
MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
return;

View File

@ -40,7 +40,7 @@ class ModelStore {
}
// Initialize ModelStore with the max count of models that need to be stored.
void Init(uint32_t max_count = 3);
void Initialize(uint32_t max_count = 3);
// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
// max_model_count_, the earliest model will be replaced.

View File

@ -302,7 +302,7 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
}
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
std::vector<std::string> aggregation_algorithm = {};
std::vector<std::string> aggregation_algorithm = {kFedAvg};
MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
return aggregation_algorithm;
}

View File

@ -0,0 +1,251 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/server.h"
#include <memory>
#include <string>
#include <csignal>
#include "ps/server/round.h"
#include "ps/server/model_store.h"
#include "ps/server/iteration.h"
#include "ps/server/collective_ops_impl.h"
#include "ps/server/distributed_metadata_store.h"
#include "ps/server/distributed_count_service.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace server {
static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
// This function handles the exit of the server process when an interrupt signal is captured.
void SignalHandler(int signal) {
MS_LOG(INFO) << "Interrupt signal captured: " << signal;
std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
return;
}
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const FuncGraphPtr &func_graph, size_t executor_threshold) {
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_ = func_graph;
if (rounds_config.empty()) {
MS_LOG(EXCEPTION) << "Rounds are empty.";
return;
}
rounds_config_ = rounds_config;
use_tcp_ = use_tcp;
use_http_ = use_http;
http_port_ = http_port;
executor_threshold_ = executor_threshold;
return;
}
// Each step of the server pipeline may depend on other steps, as follows:
// InitServerContext must be the first step to set contexts for later steps.
// Server running relies on URL or message type registration:
// StartCommunicator---->InitIteration
// Metadata registration relies on the hash ring of servers, which relies on network building completion:
// RegisterRoundKernel---->StartCommunicator
// Kernel initialization relies on executor initialization:
// RegisterRoundKernel---->InitExecutor
// Getting the model size relies on ModelStore initialization, which relies on executor initialization:
// InitCipher---->InitExecutor
void Server::Run() {
signal(SIGINT, SignalHandler);
InitServerContext();
InitCluster();
InitIteration();
StartCommunicator();
InitExecutor();
RegisterRoundKernel();
MS_LOG(INFO) << "Server started successfully.";
// Wait for the communicators to stop so that the main thread stays blocked.
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
communicator_with_server_->Join();
MsException::Instance().CheckException();
return;
}
void Server::InitServerContext() {
PSContext::instance()->GenerateResetterRound();
scheduler_ip_ = PSContext::instance()->scheduler_host();
scheduler_port_ = PSContext::instance()->scheduler_port();
worker_num_ = PSContext::instance()->initial_worker_num();
server_num_ = PSContext::instance()->initial_server_num();
return;
}
void Server::InitCluster() {
server_node_ = std::make_shared<core::ServerNode>();
MS_EXCEPTION_IF_NULL(server_node_);
task_executor_ = std::make_shared<core::TaskExecutor>(32);
MS_EXCEPTION_IF_NULL(task_executor_);
if (!InitCommunicatorWithServer()) {
MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
return;
}
if (!InitCommunicatorWithWorker()) {
MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
return;
}
global_worker_server_comms = communicators_with_worker_;
return;
}
bool Server::InitCommunicatorWithServer() {
MS_EXCEPTION_IF_NULL(task_executor_);
MS_EXCEPTION_IF_NULL(server_node_);
communicator_with_server_ =
server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
MS_EXCEPTION_IF_NULL(communicator_with_server_);
// Set exception event callbacks for server.
auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
MS_EXCEPTION_IF_NULL(tcp_comm);
tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
"started during network building phase.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() {
MS_LOG(ERROR)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
return true;
}
bool Server::InitCommunicatorWithWorker() {
MS_EXCEPTION_IF_NULL(server_node_);
MS_EXCEPTION_IF_NULL(task_executor_);
if (!use_tcp_ && !use_http_) {
MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
return false;
}
if (use_tcp_) {
auto tcp_comm = communicator_with_server_;
MS_EXCEPTION_IF_NULL(tcp_comm);
communicators_with_worker_.push_back(tcp_comm);
}
if (use_http_) {
auto http_comm = server_node_->GetOrCreateHttpComm("0.0.0.0", http_port_, task_executor_);
MS_EXCEPTION_IF_NULL(http_comm);
communicators_with_worker_.push_back(http_comm);
}
return true;
}
void Server::InitIteration() {
iteration_ = std::make_shared<Iteration>();
MS_EXCEPTION_IF_NULL(iteration_);
  // 1. Add rounds to the iteration according to the server mode.
for (const RoundConfig &config : rounds_config_) {
std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
config.check_count, config.threshold_count);
MS_LOG(INFO) << "Add round " << config.name << ", check_count: " << config.check_count
<< ", threshold:" << config.threshold_count;
iteration_->AddRound(round);
}
  // 2. Initialize all the rounds.
TimeOutCb time_out_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
FinishIterCb finish_iter_cb = std::bind(&Iteration::ProceedToNextIter, iteration_);
iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
return;
}
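A round therefore finishes either when its message counter reaches threshold_count (if check_count is set) or when its time_window elapses (if check_timeout is set); both paths call Iteration::ProceedToNextIter. A minimal Python sketch of that contract, assuming the RoundConfig fields map one-to-one onto this behavior (the names are illustrative, not the C++ API):

import time

class SketchRound:
    """Finish on a count threshold or on a timeout, whichever is enabled."""
    def __init__(self, name, check_timeout, time_window_ms, check_count, threshold_count):
        self.name = name
        self.check_timeout = check_timeout
        self.deadline = time.monotonic() + time_window_ms / 1000.0
        self.check_count = check_count
        self.threshold_count = threshold_count
        self.count = 0

    def on_message(self):
        self.count += 1

    def finished(self):
        if self.check_count and self.count >= self.threshold_count:
            return True   # enough participants reported in
        if self.check_timeout and time.monotonic() >= self.deadline:
            return True   # time window elapsed; proceed to the next iteration
        return False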
void Server::InitExecutor() {
if (executor_threshold_ == 0) {
MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
return;
}
// The train engine instance is used in both push-type and pull-type kernels,
// so the required_cnt of these kernels must be the same as update_model_threshold_.
MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
ModelStore::GetInstance().Initialize();
return;
}
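Here executor_threshold_ is the number of submitted updates the executor waits for before aggregating: update_model_threshold in FL and hybrid modes, worker_num in PS mode. A hedged sketch of such count-triggered aggregation, assuming FedAvg-style weighting by data size (the real aggregation kernels may differ):

import numpy as np

class SketchAggregator:
    """Accumulate updates until a count threshold, then return the weighted mean."""
    def __init__(self, threshold):
        self.threshold = threshold
        self.weighted_sum = None
        self.total_data_size = 0
        self.count = 0

    def submit(self, weights, data_size):
        contribution = weights * data_size
        self.weighted_sum = contribution if self.weighted_sum is None else self.weighted_sum + contribution
        self.total_data_size += data_size
        self.count += 1
        if self.count >= self.threshold:                 # threshold reached: aggregate
            return self.weighted_sum / self.total_data_size
        return None                                      # keep waiting

agg = SketchAggregator(threshold=2)
assert agg.submit(np.ones(3), 32) is None
print(agg.submit(3 * np.ones(3), 32))   # -> [2. 2. 2.]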
void Server::RegisterRoundKernel() {
MS_EXCEPTION_IF_NULL(iteration_);
auto &rounds = iteration_->rounds();
if (rounds.empty()) {
MS_LOG(EXCEPTION) << "Server has no round registered.";
return;
}
for (auto &round : rounds) {
const std::string &name = round->name();
std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
if (round_kernel == nullptr) {
MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
return;
}
// For some round kernels, the threshold count should be set.
round_kernel->InitKernel(round->threshold_count());
round->BindRoundKernel(round_kernel);
}
return;
}
void Server::StartCommunicator() {
MS_EXCEPTION_IF_NULL(communicator_with_server_);
if (communicators_with_worker_.empty()) {
MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
return;
}
MS_LOG(INFO) << "Start communicator with server.";
communicator_with_server_->Start();
DistributedMetadataStore::GetInstance().Initialize(server_node_);
CollectiveOpsImpl::GetInstance().Initialize(server_node_);
DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
MS_LOG(INFO) << "Start communicator with worker.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
}
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,131 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#include <memory>
#include <string>
#include <vector>
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/iteration.h"
namespace mindspore {
namespace ps {
namespace server {
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server {
public:
static Server &GetInstance() {
static Server instance;
return instance;
}
void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const FuncGraphPtr &func_graph, size_t executor_threshold);
  // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
  // blocked until the server is finalized.
  // func_graph is the frontend graph which will be parsed in the server's executor and aggregator.
void Run();
private:
Server()
: server_node_(nullptr),
task_executor_(nullptr),
use_tcp_(false),
use_http_(false),
http_port_(0),
func_graph_(nullptr),
executor_threshold_(0),
communicator_with_server_(nullptr),
communicators_with_worker_({}),
iteration_(nullptr),
scheduler_ip_(""),
scheduler_port_(0),
server_num_(0),
worker_num_(0) {}
~Server() = default;
Server(const Server &) = delete;
Server &operator=(const Server &) = delete;
  // Load variables which are set by ps_context.
void InitServerContext();
// Initialize the server cluster, server node and communicators.
void InitCluster();
bool InitCommunicatorWithServer();
bool InitCommunicatorWithWorker();
  // Initialize iteration with rounds. Which rounds to use can be set by ps_context as well.
void InitIteration();
// Initialize executor according to the server mode.
void InitExecutor();
// Create round kernels and bind these kernels with corresponding Round.
void RegisterRoundKernel();
// The communicators should be started after all initializations are completed.
void StartCommunicator();
// The server node is initialized in Server.
std::shared_ptr<core::ServerNode> server_node_;
  // The task executor of the communicators. This helps the server handle network messages concurrently. The tasks
  // submitted to this task executor are asynchronous.
  std::shared_ptr<core::TaskExecutor> task_executor_;
  // Which protocols the communicators should use.
  bool use_tcp_;
  bool use_http_;
  uint16_t http_port_;
  // The configuration of all rounds.
  std::vector<RoundConfig> rounds_config_;
  // The graph passed by the frontend without backend optimization.
  FuncGraphPtr func_graph_;
  // The threshold count for the executor to do aggregation or optimization.
  size_t executor_threshold_;
  // The server needs a TCP communicator to communicate with other servers for counting, metadata storing, collective
  // operations, etc.
std::shared_ptr<core::CommunicatorBase> communicator_with_server_;
  // The communication with workers (including mobile devices) has multiple protocol types: HTTP and TCP.
// In some cases, both types should be supported in one distributed training job. So here we may have multiple
// communicators.
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
// Iteration consists of multiple kinds of rounds.
std::shared_ptr<Iteration> iteration_;
// Variables set by ps context.
std::string scheduler_ip_;
uint16_t scheduler_port_;
uint32_t server_num_;
uint32_t worker_num_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_

View File

@ -33,7 +33,21 @@ def ps_context():
return _ps_context
_set_ps_context_func_map = {
"enable_ps": ps_context().set_ps_enable
"server_mode": ps_context().set_server_mode,
"ms_role": ps_context().set_ms_role,
"enable_ps": ps_context().set_ps_enable,
"worker_num": ps_context().set_worker_num,
"server_num": ps_context().set_server_num,
"scheduler_ip": ps_context().set_scheduler_ip,
"scheduler_port": ps_context().set_scheduler_port,
"fl_server_port": ps_context().set_fl_server_port,
"enable_fl_client": ps_context().set_fl_client_enable,
"start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
"fl_name": ps_context().set_fl_name,
"fl_iteration_num": ps_context().set_fl_iteration_num,
"client_epoch_num": ps_context().set_client_epoch_num,
"client_batch_size": ps_context().set_client_batch_size,
"secure_aggregation": ps_context().set_secure_aggregation
}
_get_ps_context_func_map = {

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Finish test_mobile_lenet.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)
args, _ = parser.parse_known_args()
scheduler_port = args.scheduler_port
cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"
subprocess.call(['bash', '-c', cmd])
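For the default --scheduler_port=8113, the string handed to bash -c expands to the following one-liner, which kills every matching process while excluding the grep itself and this finish script:

pid=`ps -ef|grep "scheduler_port=8113"  | grep -v "grep" | grep -v "finish" |awk '{print $2}'` && for id in $pid; do kill -9 $id && echo "killed $id"; done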

View File

@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
cmd_sched += "python ${self_path}/../test_mobile_lenet.py"
cmd_sched += " --device_target=" + device_target
cmd_sched += " --server_mode=" + server_mode
cmd_sched += " --ms_role=MS_SCHED"
cmd_sched += " --worker_num=" + str(worker_num)
cmd_sched += " --server_num=" + str(server_num)
cmd_sched += " --scheduler_ip=" + scheduler_ip
cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --fl_server_port=" + str(fl_server_port)
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])
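A hypothetical invocation (the script's file name is not shown in the diff; run_sched.py is assumed here): python run_sched.py --scheduler_ip=127.0.0.1 --scheduler_port=8113 --server_num=2. This launches test_mobile_lenet.py with --ms_role=MS_SCHED in a fresh ./scheduler/ directory and logs to scheduler.log.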

View File

@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import argparse
import subprocess
import time
parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
parser.add_argument("--local_server_num", type=int, default=-1)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
secure_aggregation = args.secure_aggregation
local_server_num = args.local_server_num
if local_server_num == -1:
local_server_num = server_num
assert local_server_num <= server_num, "The local server number should not be bigger than total server number."
for i in range(local_server_num):
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
cmd_server += "python ${self_path}/../test_mobile_lenet.py"
cmd_server += " --device_target=" + device_target
cmd_server += " --server_mode=" + server_mode
cmd_server += " --ms_role=MS_SERVER"
cmd_server += " --worker_num=" + str(worker_num)
cmd_server += " --server_num=" + str(server_num)
cmd_server += " --scheduler_ip=" + scheduler_ip
cmd_server += " --scheduler_port=" + str(scheduler_port)
cmd_server += " --fl_server_port=" + str(fl_server_port + i)
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --secure_aggregation=" + str(secure_aggregation)
cmd_server += " > server.log 2>&1 &"
        time.sleep(0.3)
        subprocess.call(['bash', '-c', cmd_server])
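A hypothetical invocation (run_server.py is an assumed file name): python run_server.py --server_num=2 --fl_server_port=6666 starts two server processes in server_0/ and server_1/, each passed --ms_role=MS_SERVER and listening on fl_server_port + i, i.e. ports 6666 and 6667. --local_server_num can be used to start only a subset of the cluster on this machine.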

View File

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export PYTHONPATH=../../../../:$PYTHONPATH
server_num=$1
worker_num=$2
ip=$3
port=$4
for((i=0;i<worker_num;i++));
do
ofs=`expr $i % $server_num`
real_port=`expr $port + $ofs`
echo $real_port
    python simulator.py --pid=$i --http_ip=$ip --http_port=$port --use_elb=True --server_num=$server_num > simulator_$i.log 2>&1 &
done
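The positional arguments are server_num, worker_num, ip and port, so a hypothetical invocation (the script name is assumed) is: bash run_client.sh 2 4 127.0.0.1 6666. This launches four simulator.py clients, and since --use_elb=True each client picks a port in [port, port + server_num) per request, emulating a load balancer in front of the servers.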

View File

@ -0,0 +1,191 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import argparse
import time
import random
import sys
import requests
import flatbuffers
import numpy as np
from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel)
parser = argparse.ArgumentParser()
parser.add_argument("--pid", type=int, default=0)
parser.add_argument("--http_ip", type=str, default="10.113.216.106")
parser.add_argument("--http_port", type=int, default=6666)
parser.add_argument("--use_elb", type=bool, default=False)
parser.add_argument("--server_num", type=int, default=1)
args, _ = parser.parse_known_args()
pid = args.pid
http_ip = args.http_ip
http_port = args.http_port
use_elb = args.use_elb
server_num = args.server_num
str_fl_id = 'fl_lenet_' + str(pid)
def generate_port():
if not use_elb:
return http_port
port = random.randint(0, 100000) % server_num + http_port
return port
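# Illustration: with use_elb=True, server_num=2 and http_port=6666, generate_port()
# returns 6666 or 6667 per request, spreading requests across the servers.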
def build_start_fl_job(iteration):
start_fl_job_builder = flatbuffers.Builder(1024)
fl_name = start_fl_job_builder.CreateString('fl_test_job')
fl_id = start_fl_job_builder.CreateString(str_fl_id)
data_size = 32
timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18')
RequestFLJob.RequestFLJobStart(start_fl_job_builder)
RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name)
RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id)
RequestFLJob.RequestFLJobAddIteration(start_fl_job_builder, iteration)
RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size)
RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp)
fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder)
start_fl_job_builder.Finish(fl_job_req)
buf = start_fl_job_builder.Output()
return buf
def build_feature_map(builder, names, lengths):
if len(names) != len(lengths):
return None
feature_maps = []
np_data = []
for j, _ in enumerate(names):
name = names[j]
length = lengths[j]
weight_full_name = builder.CreateString(name)
FeatureMap.FeatureMapStartDataVector(builder, length)
weight = np.random.rand(length) * 32
np_data.append(weight)
for idx in range(length - 1, -1, -1):
builder.PrependFloat32(weight[idx])
data = builder.EndVector(length)
FeatureMap.FeatureMapStart(builder)
FeatureMap.FeatureMapAddData(builder, data)
FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
feature_map = FeatureMap.FeatureMapEnd(builder)
feature_maps.append(feature_map)
return feature_maps, np_data
def build_update_model(iteration):
builder_update_model = flatbuffers.Builder(1)
fl_name = builder_update_model.CreateString('fl_test_job')
fl_id = builder_update_model.CreateString(str_fl_id)
timestamp = builder_update_model.CreateString('2020/11/16/19/18')
feature_maps, np_data = build_feature_map(builder_update_model,
["conv1.weight", "conv2.weight", "fc1.weight",
"fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
[450, 2400, 48000, 10080, 5208, 120, 84, 62])
RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, 1)
for single_feature_map in feature_maps:
builder_update_model.PrependUOffsetTRelative(single_feature_map)
feature_map = builder_update_model.EndVector(len(feature_maps))
RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
builder_update_model.Finish(req_update_model)
buf = builder_update_model.Output()
return buf, np_data
def build_get_model(iteration):
builder_get_model = flatbuffers.Builder(1)
fl_name = builder_get_model.CreateString('fl_test_job')
timestamp = builder_get_model.CreateString('2020/12/16/19/18')
RequestGetModel.RequestGetModelStart(builder_get_model)
RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name)
RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration)
RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp)
req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model)
builder_get_model.Finish(req_get_model)
buf = builder_get_model.Output()
return buf
weight_name_to_idx = {
"conv1.weight": 0,
"conv2.weight": 1,
"fc1.weight": 2,
"fc2.weight": 3,
"fc3.weight": 4,
"fc1.bias": 5,
"fc2.bias": 6,
"fc3.bias": 7
}
session = requests.Session()
current_iteration = 1
url = "http://" + http_ip + ":" + str(generate_port())
np.random.seed(0)
while True:
url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
print("start url is ", url1)
x = requests.post(url1, data=build_start_fl_job(current_iteration))
rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("start fl job iteration:", current_iteration, ", id:", args.pid)
while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
x = requests.post(url1, data=build_start_fl_job(current_iteration))
        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
sys.stdout.flush()
url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
print("req update model iteration:", current_iteration, ", id:", args.pid)
update_model_buf, update_model_np_data = build_update_model(current_iteration)
x = session.post(url2, data=update_model_buf)
print("rsp update model iteration:", current_iteration, ", id:", args.pid)
sys.stdout.flush()
url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
print("req get model iteration:", current_iteration, ", id:", args.pid)
x = session.post(url3, data=build_get_model(current_iteration))
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
sys.stdout.flush()
repeat_time = 0
while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
time.sleep(0.1)
x = session.post(url3, data=build_get_model(current_iteration))
rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
repeat_time += 1
if repeat_time > 1000:
print("GetModel try timeout ", args.pid)
sys.exit(0)
for i in range(0, 1):
print(rsp_get_model.FeatureMap(i).WeightFullname())
origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
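        # Assumption from the observed scaling: the server appears to divide the uploaded
        # weights by the reported data_size (32 in build_start_fl_job), so the returned
        # weights are multiplied by 32 before comparing them with the uploaded ones.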
after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
print("Before update model", args.pid, origin[0:10])
print("After get model", args.pid, after[0:10])
sys.stdout.flush()
assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
current_iteration += 1

View File

@ -0,0 +1,423 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
import numpy as np
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
"""
Update parameters by AdamWeightDecay op.
"""
if optim_filter:
adam = P.AdamWeightDecay()
if decay_flags:
next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
else:
next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
return next_param
return gradient
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
"""
Update parameters.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
overflow (Tensor): Whether overflow occurs.
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.
Returns:
        Tensor, the new parameter value after updating.
"""
if optim_filter:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_select = P.Select()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))
next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay, param_fp32) + update
update_with_lr = op_mul(lr, update)
zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return op_cast(next_param, F.dtype(param))
return gradient
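When overflow is nonzero, cond becomes an all-True mask, so each Select substitutes the previous m/v for the gradient terms and zeros for the parameter update: the parameter is left untouched on an overflowed step. A plain-numpy sketch of that masking (illustrative only, not the MindSpore ops; decay_flag is taken as True):

import numpy as np

def sketch_overflow_masked_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay, overflow):
    """Mirror the Select-based masking of _update_run_op in numpy."""
    cond = np.full(m.shape, bool(overflow))                    # overflow flag broadcast to m's shape
    next_m = beta1 * m + np.where(cond, m, (1.0 - beta1) * grad)
    next_v = beta2 * v + np.where(cond, v, (1.0 - beta2) * grad ** 2)
    update = next_m / (eps + np.sqrt(next_v))
    update = weight_decay * param + update                     # decay_flag=True branch
    next_param = param - np.where(cond, 0.0, lr * update)      # no parameter change on overflow
    return next_param, next_m, next_v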
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter and not cache_enable:
op_shape = P.Shape()
shapes = (op_shape(param), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices), shapes), param))
return success
if not target:
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
eps, values, indices))
else:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
scatter_add = P.ScatterAdd(use_locking)
assign_m = F.assign(m, op_mul(beta1, m))
assign_v = F.assign(v, op_mul(beta2, v))
grad_indices = gradient.indices
grad_value = gradient.values
next_m = scatter_add(m,
grad_indices,
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
next_v = scatter_add(v,
grad_indices,
op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))
if use_nesterov:
m_temp = next_m * _scaler_ten
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
div_value = scatter_add(m,
op_mul(grad_indices, _scaler_one),
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
param_update = div_value / (op_sqrt(next_v) + eps)
m_recover = F.assign(m, m_temp / _scaler_ten)
F.control_depend(m_temp, assign_m_nesterov)
F.control_depend(assign_m_nesterov, div_value)
F.control_depend(param_update, m_recover)
else:
param_update = next_m / (op_sqrt(next_v) + eps)
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
next_param = param - lr_t * param_update
F.control_depend(assign_m, next_m)
F.control_depend(assign_v, next_v)
success = F.depend(success, F.assign(param, next_param))
success = F.depend(success, F.assign(m, next_m))
success = F.depend(success, F.assign(v, next_v))
return success
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
moment1, moment2, ps_parameter, cache_enable):
"""Apply adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter and not cache_enable:
op_shape = P.Shape()
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
else:
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
eps, gradient))
return success
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
success = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
return success
def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
validator.check_positive_float(eps, "eps", prim_name)
class AdamWeightDecayForBert(Optimizer):
"""
Implements the Adam algorithm to fix the weight decay.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
            the order will be followed in the optimizer. There should be no other keys in the `dict`, and the
            parameters in 'order_params' must be in one of the group parameters.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (tuple[Tensor]) - The overflow flag used by dynamic loss scaling.
Outputs:
tuple[bool], all elements are True.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = AdamWeightDecay(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
self.op_select = P.Select()
self.op_cast = P.Cast()
self.op_reshape = P.Reshape()
self.op_shape = P.Shape()
def construct(self, gradients, overflow):
"""AdamWeightDecayForBert"""
lr = self.get_lr()
cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
self.op_reshape(overflow, (())), mstype.bool_)
beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
class AdamWeightDecayOp(Optimizer):
"""
Implements the Adam algorithm to fix the weight decay. It is a complete operator, not a combination of other ops.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
            the order will be followed in the optimizer. There should be no other keys in the `dict`, and the
            parameters in 'order_params' must be in one of the group parameters.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[bool], all elements are True.
Supported Platforms:
``GPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = AdamWeightDecayOp(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
def construct(self, gradients):
"""AdamWeightDecayOp"""
lr = self.get_lr()
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result

View File

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
weight_init=weight,
has_bias=False,
pad_mode="valid",
)
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=3):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
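With the 32x32 input used by the training script below, the spatial size shrinks as 32 -> 28 -> 14 -> 10 -> 5, which is where fc1's 16 * 5 * 5 input size comes from. A quick check:

h = 32            # input height/width used by the test below
h = h - 5 + 1     # conv1, 5x5 kernel, valid padding -> 28
h = h // 2        # max_pool2d, 2x2 / stride 2       -> 14
h = h - 5 + 1     # conv2                            -> 10
h = h // 2        # max_pool2d                       -> 5
assert 16 * h * h == 16 * 5 * 5  # flattened features feeding fc1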

View File

@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import argparse
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from src.model import LeNet5
from src.adam import AdamWeightDecayOp
parser = argparse.ArgumentParser(description="test_fl_lenet")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=0)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--secure_aggregation", type=ast.literal_eval, default=False)
args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
secure_aggregation = args.secure_aggregation
ctx = {
"enable_ps": False,
"server_mode": server_mode,
"ms_role": ms_role,
"worker_num": worker_num,
"server_num": server_num,
"scheduler_ip": scheduler_ip,
"scheduler_port": scheduler_port,
"fl_server_port": fl_server_port,
"start_fl_job_threshold": start_fl_job_threshold,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num,
"client_batch_size": client_batch_size,
"secure_aggregation": secure_aggregation
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_ps_context(**ctx)
if __name__ == "__main__":
epoch = 5
np.random.seed(0)
network = LeNet5(62)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_adam_opt = AdamWeightDecayOp(network.trainable_params(), weight_decay=0.1)
net_with_criterion = WithLossCell(network, criterion)
train_network = TrainOneStepCell(net_with_criterion, net_opt)
train_network.set_train()
losses = []
for _ in range(epoch):
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
print(losses)