From e227c8413f633dcbaef59387a2dd6538674faf81 Mon Sep 17 00:00:00 2001 From: emmmmtang Date: Fri, 14 Jan 2022 18:52:59 +0800 Subject: [PATCH] signds --- .../ccsrc/fl/armour/cipher/cipher_init.cc | 16 +- .../fl/armour/cipher/cipher_meta_storage.h | 5 + .../kernel/round/start_fl_job_kernel.cc | 10 +- mindspore/ccsrc/fl/server/server.cc | 10 + mindspore/ccsrc/pipeline/jit/init.cc | 14 +- mindspore/ccsrc/ps/ps_context.cc | 87 ++++- mindspore/ccsrc/ps/ps_context.h | 36 +++ .../com/mindspore/flclient/EncryptLevel.java | 1 + .../com/mindspore/flclient/FLLiteClient.java | 45 ++- .../mindspore/flclient/LocalFLParameter.java | 3 +- .../mindspore/flclient/SecureProtocol.java | 306 ++++++++++++++++++ .../com/mindspore/flclient/UpdateModel.java | 12 + mindspore/python/mindspore/context.py | 11 +- .../python/mindspore/parallel/_ps_context.py | 14 +- mindspore/schema/cipher.fbs | 19 +- tests/st/fl/albert/cloud_train.py | 18 +- tests/st/fl/albert/run_hybrid_train_server.py | 16 + .../cloud/run_lenet_server.py | 16 + .../fl/cross_device_lenet/cloud/test_lenet.py | 18 +- .../hybrid_lenet/run_hybrid_train_server.py | 16 + .../hybrid_lenet/test_hybrid_train_lenet.py | 18 +- tests/st/fl/mobile/run_mobile_server.py | 16 + tests/st/fl/mobile/test_mobile_lenet.py | 18 +- 23 files changed, 688 insertions(+), 37 deletions(-) diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc index 6cf41a3c35a..4c863c8ec23 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc @@ -55,11 +55,21 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size publicparam_.dp_delta = param.dp_delta; publicparam_.dp_norm_clip = param.dp_norm_clip; publicparam_.encrypt_type = param.encrypt_type; + publicparam_.sign_k = param.sign_k; + publicparam_.sign_eps = param.sign_eps; + publicparam_.sign_thr_ratio = param.sign_thr_ratio; + publicparam_.sign_global_lr = param.sign_global_lr; + publicparam_.sign_dim_out = param.sign_dim_out; if (param.encrypt_type == mindspore::ps::kDPEncryptType) { - MS_LOG(INFO) << "DP parameters init, dp_eps: " << param.dp_eps; - MS_LOG(INFO) << "DP parameters init, dp_delta: " << param.dp_delta; - MS_LOG(INFO) << "DP parameters init, dp_norm_clip: " << param.dp_norm_clip; + MS_LOG(INFO) << "DP parameters init, dp_eps: " << param.dp_eps << ", dp_delta: " << param.dp_delta + << ", dp_norm_clip: " << param.dp_norm_clip; + } + + if (param.encrypt_type == mindspore::ps::kDSEncryptType) { + MS_LOG(INFO) << "Sign parameters init, sign_k: " << param.sign_k << ", sign_eps: " << param.sign_eps + << ", sign_thr_ratio: " << param.sign_thr_ratio << ", sign_global_lr: " << param.sign_global_lr + << ", sign_dim_out: " << param.sign_dim_out; } if (param.encrypt_type == mindspore::ps::kPWEncryptType) { diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h index 7d1811b16d9..71235050d33 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h @@ -68,6 +68,11 @@ struct CipherPublicPara { float dp_delta; float dp_norm_clip; string encrypt_type; + float sign_k; + float sign_eps; + float sign_thr_ratio; + float sign_global_lr; + int sign_dim_out; }; class CipherMetaStorage { diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index ba071778038..2139dbc477b 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -330,9 +330,17 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, float dp_delta = param->dp_delta; float dp_norm_clip = param->dp_norm_clip; auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type()); + float sign_k = param->sign_k; + float sign_eps = param->sign_eps; + float sign_thr_ratio = param->sign_thr_ratio; + float sign_global_lr = param->sign_global_lr; + int sign_dim_out = param->sign_dim_out; + auto pw_params = schema::CreatePWParams(*fbb.get(), t, p, g, prime); + auto dp_params = schema::CreateDPParams(*fbb.get(), dp_eps, dp_delta, dp_norm_clip); + auto ds_params = schema::CreateDSParams(*fbb.get(), sign_k, sign_eps, sign_thr_ratio, sign_global_lr, sign_dim_out); auto cipher_public_params = - schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type); + schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params); #endif schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index deb53a930d0..ef88ff656b5 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -258,6 +258,11 @@ void Server::InitCipher() { float dp_delta = ps::PSContext::instance()->dp_delta(); float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); + float sign_k = ps::PSContext::instance()->sign_k(); + float sign_eps = ps::PSContext::instance()->sign_eps(); + float sign_thr_ratio = ps::PSContext::instance()->sign_thr_ratio(); + float sign_global_lr = ps::PSContext::instance()->sign_global_lr(); + int sign_dim_out = ps::PSContext::instance()->sign_dim_out(); mindspore::armour::CipherPublicPara param; param.g = cipher_g; @@ -270,6 +275,11 @@ void Server::InitCipher() { param.dp_eps = dp_eps; param.dp_norm_clip = dp_norm_clip; param.encrypt_type = encrypt_type; + param.sign_k = sign_k; + param.sign_eps = sign_eps; + param.sign_thr_ratio = sign_thr_ratio; + param.sign_global_lr = sign_global_lr; + param.sign_dim_out = sign_dim_out; BIGNUM *prim = BN_new(); if (prim == NULL) { diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index b3faf404d1f..6ff482acc83 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -468,12 +468,18 @@ PYBIND11_MODULE(_c_expression, m) { "Set configuration files required by the communication layer.") .def("config_file_path", &PSContext::config_file_path, "Get configuration files required by the communication layer.") - .def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.") - .def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.") - .def("set_dp_norm_clip", &PSContext::set_dp_norm_clip, - "Set dp norm clip for federated learning secure aggregation.") .def("set_encrypt_type", &PSContext::set_encrypt_type, "Set encrypt type for federated learning secure aggregation.") + .def("set_sign_k", &PSContext::set_sign_k, "Set sign k for federated learning SignDS.") + .def("sign_k", &PSContext::sign_k, "Get sign k for federated learning SignDS.") + .def("set_sign_eps", &PSContext::set_sign_eps, "Set sign eps for federated learning SignDS.") + .def("sign_eps", &PSContext::sign_eps, "Get sign eps for federated learning SignDS.") + .def("set_sign_thr_ratio", &PSContext::set_sign_thr_ratio, "Set sign thr ratio for federated learning SignDS.") + .def("sign_thr_ratio", &PSContext::sign_thr_ratio, "Get sign thr ratio for federated learning SignDS.") + .def("set_sign_global_lr", &PSContext::set_sign_global_lr, "Set sign global lr for federated learning SignDS.") + .def("sign_global_lr", &PSContext::sign_global_lr, "Get sign global lr for federated learning SignDS.") + .def("set_sign_dim_out", &PSContext::set_sign_dim_out, "Set sign dim out for federated learning SignDS.") + .def("sign_dim_out", &PSContext::sign_dim_out, "Get sign dim out for federated learning SignDS.") .def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.") .def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.") .def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window, diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 3085a298ea7..0c8151ca4b6 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -220,12 +220,14 @@ const std::string &PSContext::server_mode() const { return server_mode_; } void PSContext::set_encrypt_type(const std::string &encrypt_type) { if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType && - encrypt_type != kStablePWEncryptType) { - MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or " - << kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType; - return; + encrypt_type != kStablePWEncryptType && encrypt_type != kDSEncryptType) { + MS_LOG(WARNING) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or " + << kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType << " or " + << kDSEncryptType << ", DP scheme is used by default."; + encrypt_type_ = kDPEncryptType; + } else { + encrypt_type_ = encrypt_type; } - encrypt_type_ = encrypt_type; } const std::string &PSContext::encrypt_type() const { return encrypt_type_; } @@ -234,8 +236,9 @@ void PSContext::set_dp_eps(float dp_eps) { if (dp_eps > 0) { dp_eps_ = dp_eps; } else { - MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0."; - return; + MS_LOG(WARNING) << dp_eps << " is invalid, dp_eps must be larger than 0, 50 is used by default."; + float dp_eps_default = 50; + dp_eps_ = dp_eps_default; } } @@ -245,8 +248,9 @@ void PSContext::set_dp_delta(float dp_delta) { if (dp_delta > 0 && dp_delta < 1) { dp_delta_ = dp_delta; } else { - MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1)."; - return; + MS_LOG(WARNING) << dp_delta << " is invalid, dp_delta must be in range of (0, 1), 0.01 is used by default."; + float dp_delta_default = 0.01; + dp_delta_ = dp_delta_default; } } float PSContext::dp_delta() const { return dp_delta_; } @@ -255,12 +259,73 @@ void PSContext::set_dp_norm_clip(float dp_norm_clip) { if (dp_norm_clip > 0) { dp_norm_clip_ = dp_norm_clip; } else { - MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0."; - return; + MS_LOG(WARNING) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0, 1 is used by default."; + float dp_norm_clip_default = 1; + dp_norm_clip_ = dp_norm_clip_default; } } float PSContext::dp_norm_clip() const { return dp_norm_clip_; } +void PSContext::set_sign_k(float sign_k) { + float sign_k_upper = 0.25; + if (sign_k > 0 && sign_k <= sign_k_upper) { + sign_k_ = sign_k; + } else { + MS_LOG(WARNING) << sign_k << " is invalid, sign_k must be in range of (0, 0.25], 0.01 is used by default."; + float sign_k_default = 0.01; + sign_k_ = sign_k_default; + } +} +float PSContext::sign_k() const { return sign_k_; } + +void PSContext::set_sign_eps(float sign_eps) { + float sign_eps_upper = 100; + if (sign_eps > 0 && sign_eps <= sign_eps_upper) { + sign_eps_ = sign_eps; + } else { + MS_LOG(WARNING) << sign_eps << " is invalid, sign_eps must be in range of (0, 100], 100 is used by default."; + float sign_eps_default = 100; + sign_eps_ = sign_eps_default; + } +} +float PSContext::sign_eps() const { return sign_eps_; } + +void PSContext::set_sign_thr_ratio(float sign_thr_ratio) { + float sign_thr_ratio_bound = 0.5; + if (sign_thr_ratio >= sign_thr_ratio_bound && sign_thr_ratio <= 1) { + sign_thr_ratio_ = sign_thr_ratio; + } else { + MS_LOG(WARNING) << sign_thr_ratio + << " is invalid, sign_thr_ratio must be in range of [0.5, 1], 0.6 is used by default."; + float sign_thr_ratio_default = 0.6; + sign_thr_ratio_ = sign_thr_ratio_default; + } +} +float PSContext::sign_thr_ratio() const { return sign_thr_ratio_; } + +void PSContext::set_sign_global_lr(float sign_global_lr) { + if (sign_global_lr > 0) { + sign_global_lr_ = sign_global_lr; + } else { + MS_LOG(WARNING) << sign_global_lr << " is invalid, sign_global_lr must be larger than 0, 1 is used by default."; + float sign_global_lr_default = 1; + sign_global_lr_ = sign_global_lr_default; + } +} +float PSContext::sign_global_lr() const { return sign_global_lr_; } + +void PSContext::set_sign_dim_out(int sign_dim_out) { + int sign_dim_out_upper = 50; + if (sign_dim_out >= 0 && sign_dim_out <= sign_dim_out_upper) { + sign_dim_out_ = sign_dim_out; + } else { + MS_LOG(WARNING) << sign_dim_out << " is invalid, sign_dim_out must be in range of [0, 50], 0 is used by default."; + float sign_dim_out_default = 0; + sign_dim_out_ = sign_dim_out_default; + } +} +int PSContext::sign_dim_out() const { return sign_dim_out_; } + void PSContext::set_ms_role(const std::string &role) { if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) { MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid."; diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index a0729054666..c1a3132f8be 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -39,6 +39,7 @@ constexpr char kDPEncryptType[] = "DP_ENCRYPT"; constexpr char kPWEncryptType[] = "PW_ENCRYPT"; constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT"; constexpr char kNotEncryptType[] = "NOT_ENCRYPT"; +constexpr char kDSEncryptType[] = "SIGNDS"; // Use binary data to represent federated learning server's context so that we can judge which round resets the // iteration. From right to left, each bit stands for: @@ -173,6 +174,21 @@ class PSContext { void set_dp_norm_clip(float dp_norm_clip); float dp_norm_clip() const; + void set_sign_k(float sign_k); + float sign_k() const; + + void set_sign_eps(float sign_eps); + float sign_eps() const; + + void set_sign_thr_ratio(float sign_thr_ratio); + float sign_thr_ratio() const; + + void set_sign_global_lr(float sign_global_lr); + float sign_global_lr() const; + + void set_sign_dim_out(int sign_dim_out); + int sign_dim_out() const; + core::ClusterConfig &cluster_config(); void set_root_first_ca_path(const std::string &root_first_ca_path); @@ -257,6 +273,11 @@ class PSContext { dp_delta_(0.01), dp_norm_clip_(1.0), encrypt_type_(kNotEncryptType), + sign_k_(0.01), + sign_eps_(100), + sign_thr_ratio_(0.6), + sign_global_lr_(1), + sign_dim_out_(0), enable_ssl_(false), client_password_(""), server_password_(""), @@ -368,6 +389,21 @@ class PSContext { // Secure mechanism for federated learning. Used in federated learning for now. std::string encrypt_type_; + // Top-k of SignDS mechanism. + float sign_k_; + + // Privacy budget epsilon of SignDS mechanism. + float sign_eps_; + + // The threshold for the expected ratio of topk dimensions in the output of SignDS mechanism. + float sign_thr_ratio_; + + // Global learning rate of SignDS mechanism. + float sign_global_lr_; + + // The number of output dimension of SignDS mechanism. + int sign_dim_out_; + // Whether to enable ssl for network communication. bool enable_ssl_; // Password used to decode p12 file. diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/EncryptLevel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/EncryptLevel.java index 73ea77b4033..5c243a7f8c5 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/EncryptLevel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/EncryptLevel.java @@ -24,5 +24,6 @@ package com.mindspore.flclient; public enum EncryptLevel { PW_ENCRYPT, DP_ENCRYPT, + SIGNDS, NOT_ENCRYPT } \ No newline at end of file diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index 574c8654648..9f9c483723f 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -79,6 +79,11 @@ public class FLLiteClient { private String nextRequestTime; private Client client; private Map oldFeatureMap; + private float signK = 0.01f; + private float signEps = 100; + private float signThrRatio = 0.6f; + private float signGlobalLr = 1f; + private int signDimOut = 0; /** * Defining a constructor of teh class FLLiteClient. @@ -124,11 +129,11 @@ public class FLLiteClient { } switch (localFLParameter.getEncryptLevel()) { case PW_ENCRYPT: - minSecretNum = cipherPublicParams.t(); - int primeLength = cipherPublicParams.primeLength(); + minSecretNum = cipherPublicParams.pwParams().t(); + int primeLength = cipherPublicParams.pwParams().primeLength(); prime = new byte[primeLength]; for (int i = 0; i < primeLength; i++) { - prime[i] = (byte) cipherPublicParams.prime(i); + prime[i] = (byte) cipherPublicParams.pwParams().prime(i); } LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + minSecretNum)); if (minSecretNum <= 0) { @@ -138,14 +143,26 @@ public class FLLiteClient { } break; case DP_ENCRYPT: - dpEps = cipherPublicParams.dpEps(); - dpDelta = cipherPublicParams.dpDelta(); - dpNormClipFactor = cipherPublicParams.dpNormClip(); + dpEps = cipherPublicParams.dpParams().dpEps(); + dpDelta = cipherPublicParams.dpParams().dpDelta(); + dpNormClipFactor = cipherPublicParams.dpParams().dpNormClip(); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + dpEps)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + dpDelta)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + dpNormClipFactor)); break; + case SIGNDS: + signK = cipherPublicParams.dsParams().signK(); + signEps = cipherPublicParams.dsParams().signEps(); + signThrRatio = cipherPublicParams.dsParams().signThrRatio(); + signGlobalLr = cipherPublicParams.dsParams().signGlobalLr(); + signDimOut = cipherPublicParams.dsParams().signDimOut(); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + signK)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + signEps)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + signThrRatio)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + signGlobalLr)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + signDimOut)); + break; default: LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt")); } @@ -418,7 +435,7 @@ public class FLLiteClient { } LOGGER.info(Common.addTag("[getModel] get response from server ok!")); } catch (IOException e) { - failed("[getModel] un sloved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError); + failed("[getModel] unsolved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError); } return status; } @@ -511,12 +528,24 @@ public class FLLiteClient { curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap); retCode = ResponseCode.SUCCEED; if (curStatus != FLClientStatus.SUCCESS) { - LOGGER.info(Common.addTag("---Differential privacy init failed---")); + LOGGER.severe(Common.addTag("---Differential privacy init failed---")); retCode = ResponseCode.RequestError; return FLClientStatus.FAILED; } LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!")); return FLClientStatus.SUCCESS; + case SIGNDS: + // get the feature map before train + oldFeatureMap = getFeatureMap(); + curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap); + retCode = ResponseCode.SUCCEED; + if (curStatus != FLClientStatus.SUCCESS) { + LOGGER.severe(Common.addTag("---SignDS init failed---")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[Encrypt] set parameters for SignDS!")); + return FLClientStatus.SUCCESS; case NOT_ENCRYPT: retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("[Encrypt] don't mask model")); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java index 59a6cbc4d3b..f787db7292d 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java @@ -163,9 +163,10 @@ public class LocalFLParameter { } if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) && (!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) && + (!EncryptLevel.SIGNDS.toString().equals(encryptLevel)) && (!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is " + encryptLevel + " ," + - " it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before setting")); + " it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT or SIGNDS, please check it before setting")); throw new IllegalArgumentException(); } this.encryptLevel = encryptLevel; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java index 38c44ad4d84..63287dd5888 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java @@ -23,6 +23,8 @@ import mindspore.schema.FeatureMap; import java.security.SecureRandom; import java.util.ArrayList; import java.util.Map; +import java.util.List; +import java.util.HashMap; import java.util.logging.Logger; /** @@ -46,6 +48,12 @@ public class SecureProtocol { private double dpNormClip; private ArrayList updateFeatureName = new ArrayList(); private int retCode; + private float signK; + private float signEps; + private float signThrRatio; + private float signGlobalLr; + private int signDimOut; + /** * Obtain current status code in client. @@ -102,6 +110,22 @@ public class SecureProtocol { return FLClientStatus.SUCCESS; } + /** + * Setting parameters for dimension select. + * + * @param map model weights. + * @return the status code corresponding to the response message. + */ + public FLClientStatus setDSParameter(float signK, float signEps, float signThrRatio, float signGlobalLr, int signDimOut, Map map) { + this.signK = signK; + this.signEps = signEps; + this.signThrRatio = signThrRatio; + this.signGlobalLr = signGlobalLr; + this.signDimOut = signDimOut; + this.modelMap = map; + return FLClientStatus.SUCCESS; + } + /** * Obtain the feature names that needed to be encrypted. * @@ -367,6 +391,10 @@ public class SecureProtocol { } } updateL2Norm = Math.sqrt(updateL2Norm); + if (updateL2Norm == 0) { + LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check")); + return new int[0]; + } double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm); // clip and add noise @@ -412,4 +440,282 @@ public class SecureProtocol { } return featuresMap; } + + /** + * The number of combinations of n things taken k. + * + * @param n Number of things. + * @param k Number of elements taken. + * @return the total number of "n choose k" combinations. + */ + private static double comb(double n, double k) { + boolean cond = (k <= n) && (n >= 0) && (k >= 0); + double m = n + 1; + if (!cond) { + return 0; + } else { + double nTerm = Math.min(k, n - k); + double res = 1; + for (int i = 1; i <= nTerm; i++) { + res *= (m - i); + res /= i; + } + return res; + } + } + + /** + * Calculate the number of possible combinations of output set given the number of topk dimensions. + * c(k, v) * c(d-k, h-v) + * + * @param numInter the number of dimensions from topk set. + * @param topkDim the size of top-k set. + * @param inputDim total number of dimensions in the model. + * @param outputDim the number of dimensions selected for constructing sparse local updates. + * @return the number of possible combinations of output set. + */ + private static double countCombs(int numInter, int topkDim, int inputDim, int outputDim) { + return comb(topkDim, numInter) * comb(inputDim - topkDim, outputDim - numInter); + } + + /** + * Calculate the probability mass function of the number of topk dimensions in the output set. + * v is the number of dimensions from topk set. + * + * @param thr threshold of the number of topk dimensions in the output set. + * @param topkDim the size of top-k set. + * @param inputDim total number of dimensions in the model. + * @param outputDim the number of dimensions selected for constructing sparse local updates. + * @param eps the privacy budget of SignDS alg. + * @return the probability mass function. + */ + private static List calcPmf(int thr, int topkDim, int inputDim, int outputDim, float eps) { + List pmf = new ArrayList<>(); + double newPmf; + for (int v = 0; v <= outputDim; v++) { + if (v < thr) { + newPmf = countCombs(v, topkDim, inputDim, outputDim); + } else { + newPmf = countCombs(v, topkDim, inputDim, outputDim) * Math.exp(eps); + } + pmf.add(newPmf); + } + double pmfSum = 0; + for (int i = 0; i < pmf.size(); i++) { + pmfSum += pmf.get(i); + } + if (pmfSum == 0) { + LOGGER.severe(Common.addTag("[SignDS] probability mass function is 0, please check")); + return new ArrayList<>(); + } + for (int i = 0; i < pmf.size(); i++) { + pmf.set(i, pmf.get(i) / pmfSum); + } + return pmf; + } + + /** + * Calculate the expected number of topk dimensions in the output set given outputDim. + * The size of pmf is also outputDim. + * + * @param pmf probability mass function + * @return the expectation of the topk dimensions in the output set. + */ + private static double calcExpectation(List pmf) { + double sumExpectation = 0; + for (int i = 0; i < pmf.size(); i++) { + sumExpectation += (i * pmf.get(i)); + } + return sumExpectation; + } + + /** + * Calculate the optimum threshold for the number of topk dimension in the output set. + * The optimum threshold is an integer among [1, outputDim], which has the largest + * expectation value. + * + * @param topkDim the size of top-k set. + * @param inputDim total number of dimensions in the model. + * @param outputDim the number of dimensions selected for constructing sparse local updates. + * @param eps the privacy budget of SignDS alg. + * @return the optimum threshold. + */ + private static int calcOptThr(int topkDim, int inputDim, int outputDim, float eps) { + double optExpect = 0; + double optT = 0; + for (int t = 1; t <= outputDim; t++) { + double newExpect = calcExpectation(calcPmf(t, topkDim, inputDim, outputDim, eps)); + if (newExpect > optExpect) { + optExpect = newExpect; + optT = t; + } else { + break; + } + } + return (int) Math.max(optT, 1); + } + + /** + * Tool function for finding the optimum output dimension. + * The main idea is to iteratively search for the largest output dimension while + * ensuring the expected ratio of topk dimensions in the output set larger than + * the target ratio. + * + * @param thrInterRatio threshold of the expected ratio of topk dimensions + * @param topkDim the size of top-k set. + * @param inputDim total number of dimensions in the model. + * @param eps the privacy budget of SignDS alg. + * @return the optimum output dimension. + */ + private static int findOptOutputDim(float thrInterRatio, int topkDim, int inputDim, float eps) { + int outputDim = 1; + while (true) { + int thr = calcOptThr(topkDim, inputDim, outputDim, eps); + double expectedRatio = calcExpectation(calcPmf(thr, topkDim, inputDim, outputDim, eps)) / outputDim; + if (expectedRatio < thrInterRatio || Double.isNaN(expectedRatio)) { + break; + } else { + outputDim += 1; + } + } + return Math.max(1, (outputDim - 1)); + } + + /** + * Determine the number of dimensions to be sampled from the topk dimension set via + * inverse sampling. + * The main steps of the trick of inverse sampling include: + * 1. Sample a random probability from the uniform distribution U(0, 1). + * 2. Calculate the cumulative distribution of numInter, namely the number of + * topk dimensions in the output set. + * 3. Compare the cumulative distribution with the random probability and determine + * the value of numInter. + * + * @param thrDim threshold of the number of topk dimensions in the output set. + * @param denominator calculate denominator given the threshold. + * @param topkDim the size of top-k set. + * @param inputDim total number of dimensions in the model. + * @param outputDim the number of dimensions selected for constructing sparse local updates. + * @param eps the privacy budget of SignDS alg. + * @return the number of dimensions to be sampled from the top-k dimension set. + */ + private static int countInters(int thrDim, double denominator, int topkDim, int inputDim, int outputDim, float eps) { + SecureRandom secureRandom = new SecureRandom(); + double randomProb = secureRandom.nextDouble(); + int numInter = 0; + double prob = countCombs(numInter, topkDim, inputDim, outputDim) / denominator; + while (prob < randomProb) { + numInter += 1; + if (numInter < thrDim) { + prob += countCombs(numInter, topkDim, inputDim, outputDim) / denominator; + } else { + prob += Math.exp(eps) * countCombs(numInter, topkDim, inputDim, outputDim) / denominator; + } + } + return numInter; + } + + /** + * SignDS model weights. + * + * @param builder the FlatBufferBuilder object used for serialization model weights. + * @param trainDataSize tne size of train data set. + * @return the serialized model weights after adding masks. + */ + + public int[] signDSModel(FlatBufferBuilder builder, int trainDataSize, Map trainedMap) { + Map mapBeforeTrain = modelMap; + int layerNum = updateFeatureName.size(); + int[] featuresMap = new int[layerNum]; + SecureRandom secureRandom = Common.getSecureRandom(); + boolean sign = secureRandom.nextBoolean(); + List nonTopkKeyList = new ArrayList<>(); + List topkKeyList = new ArrayList<>(); + Map allUpdateMap = new HashMap<>(); + for (int i = 0; i < layerNum; i++) { + String key = updateFeatureName.get(i); + float[] dataAfterTrain = trainedMap.get(key); + float[] dataBeforeTrain = mapBeforeTrain.get(key); + for (int j = 0; j < dataAfterTrain.length; j++) { + float updateData = dataAfterTrain[j] - dataBeforeTrain[j]; + String ij = Integer.toString(i) + ',' + j; + allUpdateMap.put(ij, updateData); + } + } + int inputDim = allUpdateMap.size(); + int topkDim = (int) (signK * inputDim); + if (signDimOut == 0) { + signDimOut = findOptOutputDim(signThrRatio, topkDim, inputDim, signEps); + } + int thrDim = calcOptThr(topkDim, inputDim, signDimOut, signEps); + double combLessInter = 0d; + double combMoreInter = 0d; + for (int i = 0; i < thrDim; i++) { + combLessInter += countCombs(i, topkDim, inputDim, signDimOut); + } + for (int i = thrDim; i <= signDimOut; i++) { + combMoreInter += countCombs(i, topkDim, inputDim, signDimOut); + } + double denominator = combLessInter + Math.exp(signEps) * combMoreInter; + if (denominator == 0) { + LOGGER.severe(Common.addTag("[SignDS] denominator is 0, please check")); + return new int[0]; + } + int numInter = countInters(thrDim, denominator, topkDim, inputDim, signDimOut, signEps); + int numOuter = signDimOut - numInter; + if (topkDim < numInter || signDimOut <= 0) { + LOGGER.severe("[SignDS] topkDim or signDimOut is ERROR! please check"); + return new int[0]; + } + + List> allUpdateList = new ArrayList<>(allUpdateMap.entrySet()); + if (sign) { + allUpdateList.sort((o1, o2) -> Float.compare(o2.getValue(), o1.getValue())); + } else { + allUpdateList.sort((o1, o2) -> Float.compare(o1.getValue(), o2.getValue())); + } + for (int i = 0; i < topkDim; i++) { + topkKeyList.add(allUpdateList.get(i).getKey()); + } + for (int i = topkDim; i < allUpdateList.size(); i++) { + nonTopkKeyList.add(allUpdateList.get(i).getKey()); + } + List outputDimensionIJStringList = new ArrayList<>(); + for (int i = topkKeyList.size(); i > topkKeyList.size() - numInter; i--) { + int randomIndex = secureRandom.nextInt(i); + String randomChoiceTopkIJString = topkKeyList.get(randomIndex); + topkKeyList.set(randomIndex, topkKeyList.get(i - 1)); + topkKeyList.set(i - 1, randomChoiceTopkIJString); + outputDimensionIJStringList.add(randomChoiceTopkIJString); + } + for (int i = nonTopkKeyList.size(); i > nonTopkKeyList.size() - numOuter; i--) { + int randomIndex = secureRandom.nextInt(i); + String randomChoiceNonTopkIJString = nonTopkKeyList.get(randomIndex); + nonTopkKeyList.set(randomIndex, nonTopkKeyList.get(i - 1)); + nonTopkKeyList.set(i - 1, randomChoiceNonTopkIJString); + outputDimensionIJStringList.add(randomChoiceNonTopkIJString); + } + float signValue = sign ? 1f * signGlobalLr : -1f * signGlobalLr; + for (String ijString : outputDimensionIJStringList) { + String[] ij = ijString.split(","); + int iKeyIndex = Integer.parseInt(ij[0]); + int jDataIndex = Integer.parseInt(ij[1]); + String key = updateFeatureName.get(iKeyIndex); + float[] dataBeforeTrain = mapBeforeTrain.get(key); + dataBeforeTrain[jDataIndex] += signValue; + mapBeforeTrain.put(key, dataBeforeTrain); + } + for (int i = 0; i < layerNum; i++) { + String key = updateFeatureName.get(i); + float[] dataBeforeTrain = mapBeforeTrain.get(key); + for (int j = 0; j < dataBeforeTrain.length; j++) { + dataBeforeTrain[j] *= trainDataSize; + } + int featureName = builder.createString(key); + int weight = FeatureMap.createDataVector(builder, dataBeforeTrain); + int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); + featuresMap[i] = featureMap; + } + return featuresMap; + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java index 230b223561b..5abd708b79d 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java @@ -310,6 +310,18 @@ public class UpdateModel { this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP); LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!")); return this; + case SIGNDS: + int[] fmOffsetsSignDS = secureProtocol.signDSModel(builder, trainDataSize, trainedMap); + if (fmOffsetsSignDS == null || fmOffsetsSignDS.length == 0) { + LOGGER.severe("[Encrypt] the return fmOffsetsSignDS from is " + + "null, please check"); + retCode = ResponseCode.RequestError; + status = FLClientStatus.FAILED; + throw new IllegalArgumentException(); + } + this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignDS); + LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!")); + return this; case NOT_ENCRYPT: default: int featureSize = updateFeatureName.size(); diff --git a/mindspore/python/mindspore/context.py b/mindspore/python/mindspore/context.py index 8e622f2c14f..550dc33324a 100644 --- a/mindspore/python/mindspore/context.py +++ b/mindspore/python/mindspore/context.py @@ -1072,11 +1072,18 @@ def set_fl_context(**kwargs): dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is suggested to be 0.5~2. Default: 1.0. encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT', - 'PW_ENCRYPT' or 'STABLE_PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied + 'PW_ENCRYPT', 'STABLE_PW_ENCRYPT' or 'SIGNDS'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing in cross-device scenario. If 'STABLE_PW_ENCRYPT', pairwise secure aggregation would - be applied to protect clients' model from stealing in cross-silo scenario. Default: 'NOT_ENCRYPT'. + be applied to protect clients' model from stealing in cross-silo scenario. If 'SIGNDS', SignDS schema would + be applied for clients. Default: 'NOT_ENCRYPT'. + sign_k (float): SignDS: Top-k ratio, namely the number of top-k dimensions divided by the total number of + dimensions. Default: 0.01. + sign_eps (float): SignDS: Privacy budget. Default: 100. + sign_thr_ratio (float): SignDS: Threshold of the expected topk dimension. Default: 0.6. + sign_global_lr (float): SignDS: The constant value assigned to the selected dimension. Default: 1. + sign_dim_out (int): SignDS: Number of output dimensions. Default: 0. config_file_path (string): Configuration file path used by recovery. Default: ''. scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202. enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: False. diff --git a/mindspore/python/mindspore/parallel/_ps_context.py b/mindspore/python/mindspore/parallel/_ps_context.py index e49c1fb3d1d..815bccbfc6a 100644 --- a/mindspore/python/mindspore/parallel/_ps_context.py +++ b/mindspore/python/mindspore/parallel/_ps_context.py @@ -72,7 +72,12 @@ _set_ps_context_func_map = { "dp_norm_clip": ps_context().set_dp_norm_clip, "encrypt_type": ps_context().set_encrypt_type, "http_url_prefix": ps_context().set_http_url_prefix, - "global_iteration_time_window": ps_context().set_global_iteration_time_window + "global_iteration_time_window": ps_context().set_global_iteration_time_window, + "sign_k": ps_context().set_sign_k, + "sign_eps": ps_context().set_sign_eps, + "sign_thr_ratio": ps_context().set_sign_thr_ratio, + "sign_global_lr": ps_context().set_sign_global_lr, + "sign_dim_out": ps_context().set_sign_dim_out } _get_ps_context_func_map = { @@ -114,7 +119,12 @@ _get_ps_context_func_map = { "scheduler_manage_port": ps_context().scheduler_manage_port, "config_file_path": ps_context().config_file_path, "http_url_prefix": ps_context().http_url_prefix, - "global_iteration_time_window": ps_context().global_iteration_time_window + "global_iteration_time_window": ps_context().global_iteration_time_window, + "sign_k": ps_context().sign_k, + "sign_eps": ps_context().sign_eps, + "sign_thr_ratio": ps_context().sign_thr_ratio, + "sign_global_lr": ps_context().sign_global_lr, + "sign_dim_out": ps_context().sign_dim_out } _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", diff --git a/mindspore/schema/cipher.fbs b/mindspore/schema/cipher.fbs index 45009b1a14a..611dd871460 100644 --- a/mindspore/schema/cipher.fbs +++ b/mindspore/schema/cipher.fbs @@ -16,15 +16,32 @@ namespace mindspore.schema; -table CipherPublicParams { +table PWParams { t:int; p:[ubyte]; g:int; prime:[ubyte]; +} + +table DPParams { dp_eps:float; dp_delta:float; dp_norm_clip:float; +} + +table DSParams { + sign_k:float; + sign_eps:float; + sign_thr_ratio:float; + sign_global_lr:float; + sign_dim_out:int; +} + +table CipherPublicParams { encrypt_type:string; + pw_params:PWParams; + dp_params:DPParams; + ds_params:DSParams; } table ClientPublicKeys { diff --git a/tests/st/fl/albert/cloud_train.py b/tests/st/fl/albert/cloud_train.py index 199bdea22fd..3b6b15ce5f0 100644 --- a/tests/st/fl/albert/cloud_train.py +++ b/tests/st/fl/albert/cloud_train.py @@ -76,6 +76,12 @@ def parse_args(): parser.add_argument("--root_second_ca_path", type=str, default="") parser.add_argument("--equip_crl_path", type=str, default="") parser.add_argument("--replay_attack_time_diff", type=int, default=600000) + # parameters for 'SIGNDS' + parser.add_argument("--sign_k", type=float, default=0.01) + parser.add_argument("--sign_eps", type=float, default=100) + parser.add_argument("--sign_thr_ratio", type=float, default=0.6) + parser.add_argument("--sign_global_lr", type=float, default=0.1) + parser.add_argument("--sign_dim_out", type=int, default=0) return parser.parse_args() @@ -118,6 +124,11 @@ def server_train(args): root_second_ca_path = args.root_second_ca_path equip_crl_path = args.equip_crl_path replay_attack_time_diff = args.replay_attack_time_diff + sign_k = args.sign_k + sign_eps = args.sign_eps + sign_thr_ratio = args.sign_thr_ratio + sign_global_lr = args.sign_global_lr + sign_dim_out = args.sign_dim_out # Replace some parameters with federated learning parameters. train_cfg.max_global_epoch = fl_iteration_num @@ -155,7 +166,12 @@ def server_train(args): "root_first_ca_path": root_first_ca_path, "root_second_ca_path": root_second_ca_path, "equip_crl_path": equip_crl_path, - "replay_attack_time_diff": replay_attack_time_diff + "replay_attack_time_diff": replay_attack_time_diff, + "sign_k": sign_k, + "sign_eps": sign_eps, + "sign_thr_ratio": sign_thr_ratio, + "sign_global_lr": sign_global_lr, + "sign_dim_out": sign_dim_out } if not os.path.exists(output_dir): diff --git a/tests/st/fl/albert/run_hybrid_train_server.py b/tests/st/fl/albert/run_hybrid_train_server.py index 3c2b2182a5c..e19827f3170 100644 --- a/tests/st/fl/albert/run_hybrid_train_server.py +++ b/tests/st/fl/albert/run_hybrid_train_server.py @@ -58,6 +58,12 @@ parser.add_argument("--root_first_ca_path", type=str, default="") parser.add_argument("--root_second_ca_path", type=str, default="") parser.add_argument("--equip_crl_path", type=str, default="") parser.add_argument("--replay_attack_time_diff", type=int, default=600000) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) args, _ = parser.parse_known_args() device_target = args.device_target @@ -93,6 +99,11 @@ root_first_ca_path = args.root_first_ca_path root_second_ca_path = args.root_second_ca_path equip_crl_path = args.equip_crl_path replay_attack_time_diff = args.replay_attack_time_diff +sign_k = args.sign_k +sign_eps = args.sign_eps +sign_thr_ratio = args.sign_thr_ratio +sign_global_lr = args.sign_global_lr +sign_dim_out = args.sign_dim_out if local_server_num == -1: local_server_num = server_num @@ -139,6 +150,11 @@ for i in range(local_server_num): cmd_server += " --root_second_ca_path=" + str(root_second_ca_path) cmd_server += " --equip_crl_path=" + str(equip_crl_path) cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff) + cmd_server += " --sign_k=" + str(sign_k) + cmd_server += " --sign_eps=" + str(sign_eps) + cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio) + cmd_server += " --sign_global_lr=" + str(sign_global_lr) + cmd_server += " --sign_dim_out=" + str(sign_dim_out) cmd_server += " > server.log 2>&1 &" import time diff --git a/tests/st/fl/cross_device_lenet/cloud/run_lenet_server.py b/tests/st/fl/cross_device_lenet/cloud/run_lenet_server.py index 098c7ec47ac..f03fe2e19f9 100644 --- a/tests/st/fl/cross_device_lenet/cloud/run_lenet_server.py +++ b/tests/st/fl/cross_device_lenet/cloud/run_lenet_server.py @@ -49,6 +49,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) if __name__ == "__main__": args, _ = parser.parse_known_args() @@ -80,6 +86,11 @@ if __name__ == "__main__": client_password = args.client_password server_password = args.server_password enable_ssl = args.enable_ssl + sign_k = args.sign_k + sign_eps = args.sign_eps + sign_thr_ratio = args.sign_thr_ratio + sign_global_lr = args.sign_global_lr + sign_dim_out = args.sign_dim_out if local_server_num == -1: local_server_num = server_num @@ -121,6 +132,11 @@ if __name__ == "__main__": cmd_server += " --server_password=" + str(server_password) cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --encrypt_type=" + str(encrypt_type) + cmd_server += " --sign_k=" + str(sign_k) + cmd_server += " --sign_eps=" + str(sign_eps) + cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio) + cmd_server += " --sign_global_lr=" + str(sign_global_lr) + cmd_server += " --sign_dim_out=" + str(sign_dim_out) cmd_server += " > server.log 2>&1 &" import time diff --git a/tests/st/fl/cross_device_lenet/cloud/test_lenet.py b/tests/st/fl/cross_device_lenet/cloud/test_lenet.py index 9a84c60f080..c9d44469df2 100644 --- a/tests/st/fl/cross_device_lenet/cloud/test_lenet.py +++ b/tests/st/fl/cross_device_lenet/cloud/test_lenet.py @@ -56,6 +56,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) args, _ = parser.parse_known_args() device_target = args.device_target @@ -87,6 +93,11 @@ encrypt_type = args.encrypt_type client_password = args.client_password server_password = args.server_password enable_ssl = args.enable_ssl +sign_k = args.sign_k +sign_eps = args.sign_eps +sign_thr_ratio = args.sign_thr_ratio +sign_global_lr = args.sign_global_lr +sign_dim_out = args.sign_dim_out ctx = { "enable_fl": True, @@ -117,7 +128,12 @@ ctx = { "encrypt_type": encrypt_type, "client_password": client_password, "server_password": server_password, - "enable_ssl": enable_ssl + "enable_ssl": enable_ssl, + "sign_k": sign_k, + "sign_eps": sign_eps, + "sign_thr_ratio": sign_thr_ratio, + "sign_global_lr": sign_global_lr, + "sign_dim_out": sign_dim_out } context.set_context(mode=context.GRAPH_MODE, device_target=device_target) diff --git a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py index d0ccd71c989..195a76203cb 100644 --- a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py +++ b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py @@ -57,6 +57,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) args, _ = parser.parse_known_args() device_target = args.device_target @@ -92,6 +98,11 @@ root_first_ca_path = args.root_first_ca_path root_second_ca_path = args.root_second_ca_path equip_crl_path = args.equip_crl_path replay_attack_time_diff = args.replay_attack_time_diff +sign_k = args.sign_k +sign_eps = args.sign_eps +sign_thr_ratio = args.sign_thr_ratio +sign_global_lr = args.sign_global_lr +sign_dim_out = args.sign_dim_out if local_server_num == -1: local_server_num = server_num @@ -138,6 +149,11 @@ for i in range(local_server_num): cmd_server += " --root_second_ca_path=" + str(root_second_ca_path) cmd_server += " --equip_crl_path=" + str(equip_crl_path) cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff) + cmd_server += " --sign_k=" + str(sign_k) + cmd_server += " --sign_eps=" + str(sign_eps) + cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio) + cmd_server += " --sign_global_lr=" + str(sign_global_lr) + cmd_server += " --sign_dim_out=" + str(sign_dim_out) cmd_server += " > server.log 2>&1 &" import time diff --git a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py index 817a51be035..65f27b8d4ab 100644 --- a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py +++ b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py @@ -65,6 +65,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) args, _ = parser.parse_known_args() device_target = args.device_target @@ -102,6 +108,11 @@ root_first_ca_path = args.root_first_ca_path root_second_ca_path = args.root_second_ca_path equip_crl_path = args.equip_crl_path replay_attack_time_diff = args.replay_attack_time_diff +sign_k = args.sign_k +sign_eps = args.sign_eps +sign_thr_ratio = args.sign_thr_ratio +sign_global_lr = args.sign_global_lr +sign_dim_out = args.sign_dim_out ctx = { "enable_fl": True, @@ -138,7 +149,12 @@ ctx = { "encrypt_type": encrypt_type, "client_password": client_password, "server_password": server_password, - "enable_ssl": enable_ssl + "enable_ssl": enable_ssl, + "sign_k": sign_k, + "sign_eps": sign_eps, + "sign_thr_ratio": sign_thr_ratio, + "sign_global_lr": sign_global_lr, + "sign_dim_out": sign_dim_out } context.set_context(mode=context.GRAPH_MODE, device_target=device_target) diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py index 42ad98bfd1e..8d2cb049c01 100644 --- a/tests/st/fl/mobile/run_mobile_server.py +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -55,6 +55,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) if __name__ == "__main__": args, _ = parser.parse_known_args() @@ -91,6 +97,11 @@ if __name__ == "__main__": client_password = args.client_password server_password = args.server_password enable_ssl = args.enable_ssl + sign_k = args.sign_k + sign_eps = args.sign_eps + sign_thr_ratio = args.sign_thr_ratio + sign_global_lr = args.sign_global_lr + sign_dim_out = args.sign_dim_out if local_server_num == -1: local_server_num = server_num @@ -137,6 +148,11 @@ if __name__ == "__main__": cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff) cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --encrypt_type=" + str(encrypt_type) + cmd_server += " --sign_k=" + str(sign_k) + cmd_server += " --sign_eps=" + str(sign_eps) + cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio) + cmd_server += " --sign_global_lr=" + str(sign_global_lr) + cmd_server += " --sign_dim_out=" + str(sign_dim_out) cmd_server += " > server.log 2>&1 &" import time diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py index 163e435c0c3..2235060c549 100644 --- a/tests/st/fl/mobile/test_mobile_lenet.py +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -64,6 +64,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--client_password", type=str, default="") parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) +# parameters for 'SIGNDS' +parser.add_argument("--sign_k", type=float, default=0.01) +parser.add_argument("--sign_eps", type=float, default=100) +parser.add_argument("--sign_thr_ratio", type=float, default=0.6) +parser.add_argument("--sign_global_lr", type=float, default=0.1) +parser.add_argument("--sign_dim_out", type=int, default=0) args, _ = parser.parse_known_args() device_target = args.device_target @@ -100,6 +106,11 @@ replay_attack_time_diff = args.replay_attack_time_diff client_password = args.client_password server_password = args.server_password enable_ssl = args.enable_ssl +sign_k = args.sign_k +sign_eps = args.sign_eps +sign_thr_ratio = args.sign_thr_ratio +sign_global_lr = args.sign_global_lr +sign_dim_out = args.sign_dim_out ctx = { "enable_fl": True, @@ -135,7 +146,12 @@ ctx = { "encrypt_type": encrypt_type, "client_password": client_password, "server_password": server_password, - "enable_ssl": enable_ssl + "enable_ssl": enable_ssl, + "sign_k": sign_k, + "sign_eps": sign_eps, + "sign_thr_ratio": sign_thr_ratio, + "sign_global_lr": sign_global_lr, + "sign_dim_out": sign_dim_out } context.set_context(mode=context.GRAPH_MODE, device_target=device_target)