diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc index d2dd96c7738..b2988ecec0b 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc @@ -71,11 +71,19 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size } bool CipherInit::Check_Parames() { - if (featuremap_ < 1 || secrets_minnums_ < 1 || share_clients_num_need_ < reconstruct_clients_num_need_ || - reconstruct_clients_num_need_ <= secrets_minnums_ || client_num_need_ < share_clients_num_need_) { - MS_LOG(ERROR) << "CIPHER Init Params are illegal."; + if (featuremap_ < 1) { + MS_LOG(ERROR) << "Featuremap size should be positive, but got " << featuremap_; return false; } + + if (share_clients_num_need_ < reconstruct_clients_num_need_) { + MS_LOG(ERROR) + << "reconstruct_clients_num_need (which is reconstruct_secrets_threshold + 1) should not be larger " + "than share_clients_num_need (which is start_fl_job_threshold*share_secrets_ratio), but got they are:" + << reconstruct_clients_num_need_ << ", " << share_clients_num_need_; + return false; + } + return true; } diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index c958808ccf3..a53e6697203 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -62,7 +62,7 @@ struct RoundConfig { struct CipherConfig { float share_secrets_ratio = 1.0; uint64_t cipher_time_window = 300000; - size_t reconstruct_secrets_threshhold = 0; + size_t reconstruct_secrets_threshold = 0; }; using mindspore::kernel::Address; diff --git a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc index 98c6a756fda..ebf539e1839 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc @@ -109,7 +109,6 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient bool ClientListKernel::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) { std::shared_ptr fbb = std::make_shared(); - // bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration; @@ -139,7 +138,6 @@ bool ClientListKernel::Launch(const std::vector &inputs, const std:: "GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num); } else { - // response = DealClient(iter_num, get_clients_req, fbb); DealClient(iter_num, get_clients_req, fbb); } } diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index 0ea8924cd72..b7009501373 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -110,6 +110,11 @@ void Server::InitServerContext() { scheduler_port_ = ps::PSContext::instance()->scheduler_port(); worker_num_ = ps::PSContext::instance()->initial_worker_num(); server_num_ = ps::PSContext::instance()->initial_server_num(); + std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); + if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) { + MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_; + return; + } return; } @@ -183,7 +188,7 @@ void Server::InitIteration() { cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio; cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count; cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count; - cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold; + cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold; cipher_time_window_ = cipher_config_.cipher_time_window; MS_LOG(INFO) << "Initializing cipher:"; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index b10f954222c..7405a720ad1 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -674,9 +674,9 @@ bool StartServerAction(const ResourcePtr &res) { float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio(); uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window(); - size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold(); + size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold(); - fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold}; + fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold}; size_t executor_threshold = 0; if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index c9bbf7c0f1f..18b5b7a75d1 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -372,9 +372,9 @@ PYBIND11_MODULE(_c_expression, m) { "Set threshold count ratio for share secrets round.") .def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.") .def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.") - .def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold, + .def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold, "Set threshold count for reconstruct secrets round.") - .def("reconstruct_secrets_threshhold", &PSContext::reconstruct_secrets_threshhold, + .def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold, "Get threshold count for reconstruct secrets round.") .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.") .def("fl_name", &PSContext::fl_name, "Get federated learning name.") diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 9f05f2443a8..b6dfd69789d 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -336,24 +336,36 @@ void PSContext::set_update_model_time_window(uint64_t update_model_time_window) uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; } -void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secrets_ratio_ = share_secrets_ratio; } +void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { + if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) { + share_secrets_ratio_ = share_secrets_ratio; + } else { + MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1]."; + return; + } +} float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; } void PSContext::set_cipher_time_window(uint64_t cipher_time_window) { if (cipher_time_window_ < 0) { - MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.."; + MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0."; + return; } cipher_time_window_ = cipher_time_window; } uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; } -void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) { - reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold; +void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) { + if (reconstruct_secrets_threshold <= 0) { + MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive."; + return; + } + reconstruct_secrets_threshold_ = reconstruct_secrets_threshold; } -uint64_t PSContext::reconstruct_secrets_threshhold() const { return reconstruct_secrets_threshhold_; } +uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; } void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; } diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 124c6bf6ee2..e9da56ade2b 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -134,8 +134,8 @@ class PSContext { void set_cipher_time_window(uint64_t cipher_time_window); uint64_t cipher_time_window() const; - void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold); - uint64_t reconstruct_secrets_threshhold() const; + void set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold); + uint64_t reconstruct_secrets_threshold() const; void set_fl_name(const std::string &fl_name); const std::string &fl_name() const; @@ -202,7 +202,7 @@ class PSContext { update_model_time_window_(3000), share_secrets_ratio_(1.0), cipher_time_window_(300000), - reconstruct_secrets_threshhold_(2000), + reconstruct_secrets_threshold_(2000), fl_iteration_num_(20), client_epoch_num_(25), client_batch_size_(32), @@ -264,7 +264,7 @@ class PSContext { uint64_t cipher_time_window_; // The threshold count of reconstruct secrets round. Used in federated learning for now. - uint64_t reconstruct_secrets_threshhold_; + uint64_t reconstruct_secrets_threshold_; // Iteration number of federeated learning, which is the number of interactions between client and server. uint64_t fl_iteration_num_; diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index 6a9f7897a59..c7a0fc89de7 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -59,7 +59,7 @@ _set_ps_context_func_map = { "update_model_time_window": ps_context().set_update_model_time_window, "share_secrets_ratio": ps_context().set_share_secrets_ratio, "cipher_time_window": ps_context().set_cipher_time_window, - "reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold, + "reconstruct_secrets_threshold": ps_context().set_reconstruct_secrets_threshold, "fl_name": ps_context().set_fl_name, "fl_iteration_num": ps_context().set_fl_iteration_num, "client_epoch_num": ps_context().set_client_epoch_num, @@ -92,7 +92,7 @@ _get_ps_context_func_map = { "update_model_time_window": ps_context().update_model_time_window, "share_secrets_ratio": ps_context().share_secrets_ratio, "cipher_time_window": ps_context().set_cipher_time_window, - "reconstruct_secrets_threshhold": ps_context().reconstruct_secrets_threshhold, + "reconstruct_secrets_threshold": ps_context().reconstruct_secrets_threshold, "fl_name": ps_context().fl_name, "fl_iteration_num": ps_context().fl_iteration_num, "client_epoch_num": ps_context().client_epoch_num, diff --git a/tests/st/fl/albert/cloud_train.py b/tests/st/fl/albert/cloud_train.py index f371c32173b..ba13ba28d4c 100644 --- a/tests/st/fl/albert/cloud_train.py +++ b/tests/st/fl/albert/cloud_train.py @@ -61,8 +61,11 @@ def parse_args(): parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold - parser.add_argument("--dp_norm_clip", type=float, default=0.05) + parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") + parser.add_argument("--share_secrets_ratio", type=float, default=1.0) + parser.add_argument("--cipher_time_window", type=int, default=300000) + parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) return parser.parse_args() @@ -94,6 +97,9 @@ def server_train(args): dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip encrypt_type = args.encrypt_type + share_secrets_ratio = args.share_secrets_ratio + cipher_time_window = args.cipher_time_window + reconstruct_secrets_threshold = args.reconstruct_secrets_threshold # Replace some parameters with federated learning parameters. train_cfg.max_global_epoch = fl_iteration_num @@ -119,7 +125,10 @@ def server_train(args): "scheduler_manage_port": scheduler_manage_port, "dp_delta": dp_delta, "dp_norm_clip": dp_norm_clip, - "encrypt_type": encrypt_type + "encrypt_type": encrypt_type, + "share_secrets_ratio": share_secrets_ratio, + "cipher_time_window": cipher_time_window, + "reconstruct_secrets_threshold": reconstruct_secrets_threshold, } if not os.path.exists(output_dir): diff --git a/tests/st/fl/albert/run_hybrid_train_server.py b/tests/st/fl/albert/run_hybrid_train_server.py index e8296797680..b050dda8274 100644 --- a/tests/st/fl/albert/run_hybrid_train_server.py +++ b/tests/st/fl/albert/run_hybrid_train_server.py @@ -35,10 +35,15 @@ parser.add_argument("--client_batch_size", type=int, default=32) parser.add_argument("--client_learning_rate", type=float, default=0.1) # The number of servers that this script will launch. parser.add_argument("--local_server_num", type=int, default=-1) -parser.add_argument("--dp_eps", type=float, default=50.0) -parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold -parser.add_argument("--dp_norm_clip", type=float, default=0.05) parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +# parameters for encrypt_type='DP_ENCRYPT' +parser.add_argument("--dp_eps", type=float, default=50.0) +parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num +parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -62,6 +67,9 @@ dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip encrypt_type = args.encrypt_type +share_secrets_ratio = args.share_secrets_ratio +cipher_time_window = args.cipher_time_window +reconstruct_secrets_threshold = args.reconstruct_secrets_threshold if local_server_num == -1: local_server_num = server_num @@ -95,6 +103,9 @@ for i in range(local_server_num): cmd_server += " --dp_delta=" + str(dp_delta) cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) cmd_server += " --encrypt_type=" + str(encrypt_type) + cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) + cmd_server += " --cipher_time_window=" + str(cipher_time_window) + cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) cmd_server += " > server.log 2>&1 &" import time diff --git a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py index 9ea54d3030d..8987e3c54cb 100644 --- a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py +++ b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py @@ -43,7 +43,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) # parameters for encrypt_type='PW_ENCRYPT' parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -67,7 +67,7 @@ config_file_path = args.config_file_path encrypt_type = args.encrypt_type share_secrets_ratio = args.share_secrets_ratio cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold +reconstruct_secrets_threshold = args.reconstruct_secrets_threshold dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip @@ -104,7 +104,7 @@ for i in range(local_server_num): cmd_server += " --encrypt_type=" + str(encrypt_type) cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) cmd_server += " --cipher_time_window=" + str(cipher_time_window) - cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold) + cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) cmd_server += " --dp_eps=" + str(dp_eps) cmd_server += " --dp_delta=" + str(dp_delta) cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) diff --git a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py index 188916daf3d..c2f7e5c6c3e 100644 --- a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py +++ b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py @@ -53,7 +53,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) # parameters for encrypt_type='PW_ENCRYPT' parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -79,7 +79,7 @@ config_file_path = args.config_file_path encrypt_type = args.encrypt_type share_secrets_ratio = args.share_secrets_ratio cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold +reconstruct_secrets_threshold = args.reconstruct_secrets_threshold dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip @@ -107,7 +107,7 @@ ctx = { "config_file_path": config_file_path, "share_secrets_ratio": share_secrets_ratio, "cipher_time_window": cipher_time_window, - "reconstruct_secrets_threshhold": reconstruct_secrets_threshhold, + "reconstruct_secrets_threshold": reconstruct_secrets_threshold, "dp_eps": dp_eps, "dp_delta": dp_delta, "dp_norm_clip": dp_norm_clip, diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py index 26fd7f0cd5b..2e46d56566a 100644 --- a/tests/st/fl/mobile/run_mobile_server.py +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -43,7 +43,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) # parameters for encrypt_type='PW_ENCRYPT' parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) if __name__ == "__main__": args, _ = parser.parse_known_args() @@ -60,7 +60,7 @@ if __name__ == "__main__": update_model_time_window = args.update_model_time_window share_secrets_ratio = args.share_secrets_ratio cipher_time_window = args.cipher_time_window - reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold + reconstruct_secrets_threshold = args.reconstruct_secrets_threshold fl_name = args.fl_name fl_iteration_num = args.fl_iteration_num client_epoch_num = args.client_epoch_num @@ -98,7 +98,7 @@ if __name__ == "__main__": cmd_server += " --update_model_time_window=" + str(update_model_time_window) cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) cmd_server += " --cipher_time_window=" + str(cipher_time_window) - cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold) + cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) cmd_server += " --fl_name=" + fl_name cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) cmd_server += " --config_file_path=" + str(config_file_path) diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py index a6fded1ccec..7916853fa59 100644 --- a/tests/st/fl/mobile/test_mobile_lenet.py +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -51,7 +51,7 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) # parameters for encrypt_type='PW_ENCRYPT' parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -68,7 +68,7 @@ update_model_ratio = args.update_model_ratio update_model_time_window = args.update_model_time_window share_secrets_ratio = args.share_secrets_ratio cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold +reconstruct_secrets_threshold = args.reconstruct_secrets_threshold fl_name = args.fl_name fl_iteration_num = args.fl_iteration_num client_epoch_num = args.client_epoch_num @@ -96,7 +96,7 @@ ctx = { "update_model_time_window": update_model_time_window, "share_secrets_ratio": share_secrets_ratio, "cipher_time_window": cipher_time_window, - "reconstruct_secrets_threshhold": reconstruct_secrets_threshhold, + "reconstruct_secrets_threshold": reconstruct_secrets_threshold, "fl_name": fl_name, "fl_iteration_num": fl_iteration_num, "client_epoch_num": client_epoch_num,