Fix an issue of mindspore federated

This commit is contained in:
jin-xiulang 2021-07-09 15:37:04 +08:00
parent 1f1065eab0
commit 5c02e1dc62
15 changed files with 82 additions and 39 deletions

View File

@ -71,11 +71,19 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
}
bool CipherInit::Check_Parames() {
if (featuremap_ < 1 || secrets_minnums_ < 1 || share_clients_num_need_ < reconstruct_clients_num_need_ ||
reconstruct_clients_num_need_ <= secrets_minnums_ || client_num_need_ < share_clients_num_need_) {
MS_LOG(ERROR) << "CIPHER Init Params are illegal.";
if (featuremap_ < 1) {
MS_LOG(ERROR) << "Featuremap size should be positive, but got " << featuremap_;
return false;
}
if (share_clients_num_need_ < reconstruct_clients_num_need_) {
MS_LOG(ERROR)
<< "reconstruct_clients_num_need (which is reconstruct_secrets_threshold + 1) should not be larger "
"than share_clients_num_need (which is start_fl_job_threshold*share_secrets_ratio), but got they are:"
<< reconstruct_clients_num_need_ << ", " << share_clients_num_need_;
return false;
}
return true;
}

View File

@ -62,7 +62,7 @@ struct RoundConfig {
struct CipherConfig {
float share_secrets_ratio = 1.0;
uint64_t cipher_time_window = 300000;
size_t reconstruct_secrets_threshhold = 0;
size_t reconstruct_secrets_threshold = 0;
};
using mindspore::kernel::Address;

View File

@ -109,7 +109,6 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
// bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
@ -139,7 +138,6 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
// response = DealClient(iter_num, get_clients_req, fbb);
DealClient(iter_num, get_clients_req, fbb);
}
}

View File

@ -110,6 +110,11 @@ void Server::InitServerContext() {
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
worker_num_ = ps::PSContext::instance()->initial_worker_num();
server_num_ = ps::PSContext::instance()->initial_server_num();
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) {
MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_;
return;
}
return;
}
@ -183,7 +188,7 @@ void Server::InitIteration() {
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold;
cipher_time_window_ = cipher_config_.cipher_time_window;
MS_LOG(INFO) << "Initializing cipher:";

View File

@ -674,9 +674,9 @@ bool StartServerAction(const ResourcePtr &res) {
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold();
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -372,9 +372,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Set threshold count ratio for share secrets round.")
.def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
.def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
.def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold,
.def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold,
"Set threshold count for reconstruct secrets round.")
.def("reconstruct_secrets_threshhold", &PSContext::reconstruct_secrets_threshhold,
.def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold,
"Get threshold count for reconstruct secrets round.")
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
.def("fl_name", &PSContext::fl_name, "Get federated learning name.")

View File

@ -336,24 +336,36 @@ void PSContext::set_update_model_time_window(uint64_t update_model_time_window)
uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secrets_ratio_ = share_secrets_ratio; }
void PSContext::set_share_secrets_ratio(float share_secrets_ratio) {
if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) {
share_secrets_ratio_ = share_secrets_ratio;
} else {
MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1].";
return;
}
}
float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
if (cipher_time_window_ < 0) {
MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0..";
MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.";
return;
}
cipher_time_window_ = cipher_time_window;
}
uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) {
reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold;
void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
if (reconstruct_secrets_threshold <= 0) {
MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
return;
}
reconstruct_secrets_threshold_ = reconstruct_secrets_threshold;
}
uint64_t PSContext::reconstruct_secrets_threshhold() const { return reconstruct_secrets_threshhold_; }
uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; }
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }

View File

@ -134,8 +134,8 @@ class PSContext {
void set_cipher_time_window(uint64_t cipher_time_window);
uint64_t cipher_time_window() const;
void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold);
uint64_t reconstruct_secrets_threshhold() const;
void set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold);
uint64_t reconstruct_secrets_threshold() const;
void set_fl_name(const std::string &fl_name);
const std::string &fl_name() const;
@ -202,7 +202,7 @@ class PSContext {
update_model_time_window_(3000),
share_secrets_ratio_(1.0),
cipher_time_window_(300000),
reconstruct_secrets_threshhold_(2000),
reconstruct_secrets_threshold_(2000),
fl_iteration_num_(20),
client_epoch_num_(25),
client_batch_size_(32),
@ -264,7 +264,7 @@ class PSContext {
uint64_t cipher_time_window_;
// The threshold count of reconstruct secrets round. Used in federated learning for now.
uint64_t reconstruct_secrets_threshhold_;
uint64_t reconstruct_secrets_threshold_;
// Iteration number of federeated learning, which is the number of interactions between client and server.
uint64_t fl_iteration_num_;

View File

@ -59,7 +59,7 @@ _set_ps_context_func_map = {
"update_model_time_window": ps_context().set_update_model_time_window,
"share_secrets_ratio": ps_context().set_share_secrets_ratio,
"cipher_time_window": ps_context().set_cipher_time_window,
"reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold,
"reconstruct_secrets_threshold": ps_context().set_reconstruct_secrets_threshold,
"fl_name": ps_context().set_fl_name,
"fl_iteration_num": ps_context().set_fl_iteration_num,
"client_epoch_num": ps_context().set_client_epoch_num,
@ -92,7 +92,7 @@ _get_ps_context_func_map = {
"update_model_time_window": ps_context().update_model_time_window,
"share_secrets_ratio": ps_context().share_secrets_ratio,
"cipher_time_window": ps_context().set_cipher_time_window,
"reconstruct_secrets_threshhold": ps_context().reconstruct_secrets_threshhold,
"reconstruct_secrets_threshold": ps_context().reconstruct_secrets_threshold,
"fl_name": ps_context().fl_name,
"fl_iteration_num": ps_context().fl_iteration_num,
"client_epoch_num": ps_context().client_epoch_num,

View File

@ -61,8 +61,11 @@ def parse_args():
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=0.05)
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
return parser.parse_args()
@ -94,6 +97,9 @@ def server_train(args):
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
# Replace some parameters with federated learning parameters.
train_cfg.max_global_epoch = fl_iteration_num
@ -119,7 +125,10 @@ def server_train(args):
"scheduler_manage_port": scheduler_manage_port,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
"encrypt_type": encrypt_type,
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
}
if not os.path.exists(output_dir):

View File

@ -35,10 +35,15 @@ parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
# The number of servers that this script will launch.
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold
parser.add_argument("--dp_norm_clip", type=float, default=0.05)
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -62,6 +67,9 @@ dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
if local_server_num == -1:
local_server_num = server_num
@ -95,6 +103,9 @@ for i in range(local_server_num):
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -43,7 +43,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -67,7 +67,7 @@ config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
@ -104,7 +104,7 @@ for i in range(local_server_num):
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)

View File

@ -53,7 +53,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -79,7 +79,7 @@ config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
@ -107,7 +107,7 @@ ctx = {
"config_file_path": config_file_path,
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
"dp_eps": dp_eps,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,

View File

@ -43,7 +43,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -60,7 +60,7 @@ if __name__ == "__main__":
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
@ -98,7 +98,7 @@ if __name__ == "__main__":
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --config_file_path=" + str(config_file_path)

View File

@ -51,7 +51,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -68,7 +68,7 @@ update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
@ -96,7 +96,7 @@ ctx = {
"update_model_time_window": update_model_time_window,
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num,