signds
This commit is contained in:
parent
c72a1146fb
commit
09c938a584
|
@ -55,11 +55,21 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
|||
publicparam_.dp_delta = param.dp_delta;
|
||||
publicparam_.dp_norm_clip = param.dp_norm_clip;
|
||||
publicparam_.encrypt_type = param.encrypt_type;
|
||||
publicparam_.sign_k = param.sign_k;
|
||||
publicparam_.sign_eps = param.sign_eps;
|
||||
publicparam_.sign_thr_ratio = param.sign_thr_ratio;
|
||||
publicparam_.sign_global_lr = param.sign_global_lr;
|
||||
publicparam_.sign_dim_out = param.sign_dim_out;
|
||||
|
||||
if (param.encrypt_type == mindspore::ps::kDPEncryptType) {
|
||||
MS_LOG(INFO) << "DP parameters init, dp_eps: " << param.dp_eps;
|
||||
MS_LOG(INFO) << "DP parameters init, dp_delta: " << param.dp_delta;
|
||||
MS_LOG(INFO) << "DP parameters init, dp_norm_clip: " << param.dp_norm_clip;
|
||||
MS_LOG(INFO) << "DP parameters init, dp_eps: " << param.dp_eps << ", dp_delta: " << param.dp_delta
|
||||
<< ", dp_norm_clip: " << param.dp_norm_clip;
|
||||
}
|
||||
|
||||
if (param.encrypt_type == mindspore::ps::kDSEncryptType) {
|
||||
MS_LOG(INFO) << "Sign parameters init, sign_k: " << param.sign_k << ", sign_eps: " << param.sign_eps
|
||||
<< ", sign_thr_ratio: " << param.sign_thr_ratio << ", sign_global_lr: " << param.sign_global_lr
|
||||
<< ", sign_dim_out: " << param.sign_dim_out;
|
||||
}
|
||||
|
||||
if (param.encrypt_type == mindspore::ps::kPWEncryptType) {
|
||||
|
|
|
@ -68,6 +68,11 @@ struct CipherPublicPara {
|
|||
float dp_delta;
|
||||
float dp_norm_clip;
|
||||
string encrypt_type;
|
||||
float sign_k;
|
||||
float sign_eps;
|
||||
float sign_thr_ratio;
|
||||
float sign_global_lr;
|
||||
int sign_dim_out;
|
||||
};
|
||||
|
||||
class CipherMetaStorage {
|
||||
|
|
|
@ -330,9 +330,17 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
float dp_delta = param->dp_delta;
|
||||
float dp_norm_clip = param->dp_norm_clip;
|
||||
auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type());
|
||||
float sign_k = param->sign_k;
|
||||
float sign_eps = param->sign_eps;
|
||||
float sign_thr_ratio = param->sign_thr_ratio;
|
||||
float sign_global_lr = param->sign_global_lr;
|
||||
int sign_dim_out = param->sign_dim_out;
|
||||
|
||||
auto pw_params = schema::CreatePWParams(*fbb.get(), t, p, g, prime);
|
||||
auto dp_params = schema::CreateDPParams(*fbb.get(), dp_eps, dp_delta, dp_norm_clip);
|
||||
auto ds_params = schema::CreateDSParams(*fbb.get(), sign_k, sign_eps, sign_thr_ratio, sign_global_lr, sign_dim_out);
|
||||
auto cipher_public_params =
|
||||
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
|
||||
schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params);
|
||||
#endif
|
||||
|
||||
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
||||
|
|
|
@ -256,6 +256,11 @@ void Server::InitCipher() {
|
|||
float dp_delta = ps::PSContext::instance()->dp_delta();
|
||||
float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
float sign_k = ps::PSContext::instance()->sign_k();
|
||||
float sign_eps = ps::PSContext::instance()->sign_eps();
|
||||
float sign_thr_ratio = ps::PSContext::instance()->sign_thr_ratio();
|
||||
float sign_global_lr = ps::PSContext::instance()->sign_global_lr();
|
||||
int sign_dim_out = ps::PSContext::instance()->sign_dim_out();
|
||||
|
||||
mindspore::armour::CipherPublicPara param;
|
||||
param.g = cipher_g;
|
||||
|
@ -268,6 +273,11 @@ void Server::InitCipher() {
|
|||
param.dp_eps = dp_eps;
|
||||
param.dp_norm_clip = dp_norm_clip;
|
||||
param.encrypt_type = encrypt_type;
|
||||
param.sign_k = sign_k;
|
||||
param.sign_eps = sign_eps;
|
||||
param.sign_thr_ratio = sign_thr_ratio;
|
||||
param.sign_global_lr = sign_global_lr;
|
||||
param.sign_dim_out = sign_dim_out;
|
||||
|
||||
BIGNUM *prim = BN_new();
|
||||
if (prim == NULL) {
|
||||
|
|
|
@ -468,12 +468,18 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Set configuration files required by the communication layer.")
|
||||
.def("config_file_path", &PSContext::config_file_path,
|
||||
"Get configuration files required by the communication layer.")
|
||||
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
|
||||
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
|
||||
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
|
||||
"Set dp norm clip for federated learning secure aggregation.")
|
||||
.def("set_encrypt_type", &PSContext::set_encrypt_type,
|
||||
"Set encrypt type for federated learning secure aggregation.")
|
||||
.def("set_sign_k", &PSContext::set_sign_k, "Set sign k for federated learning SignDS.")
|
||||
.def("sign_k", &PSContext::sign_k, "Get sign k for federated learning SignDS.")
|
||||
.def("set_sign_eps", &PSContext::set_sign_eps, "Set sign eps for federated learning SignDS.")
|
||||
.def("sign_eps", &PSContext::sign_eps, "Get sign eps for federated learning SignDS.")
|
||||
.def("set_sign_thr_ratio", &PSContext::set_sign_thr_ratio, "Set sign thr ratio for federated learning SignDS.")
|
||||
.def("sign_thr_ratio", &PSContext::sign_thr_ratio, "Get sign thr ratio for federated learning SignDS.")
|
||||
.def("set_sign_global_lr", &PSContext::set_sign_global_lr, "Set sign global lr for federated learning SignDS.")
|
||||
.def("sign_global_lr", &PSContext::sign_global_lr, "Get sign global lr for federated learning SignDS.")
|
||||
.def("set_sign_dim_out", &PSContext::set_sign_dim_out, "Set sign dim out for federated learning SignDS.")
|
||||
.def("sign_dim_out", &PSContext::sign_dim_out, "Get sign dim out for federated learning SignDS.")
|
||||
.def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.")
|
||||
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.");
|
||||
|
||||
|
|
|
@ -220,12 +220,14 @@ const std::string &PSContext::server_mode() const { return server_mode_; }
|
|||
|
||||
void PSContext::set_encrypt_type(const std::string &encrypt_type) {
|
||||
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType &&
|
||||
encrypt_type != kStablePWEncryptType) {
|
||||
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
|
||||
<< kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType;
|
||||
return;
|
||||
encrypt_type != kStablePWEncryptType && encrypt_type != kDSEncryptType) {
|
||||
MS_LOG(WARNING) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
|
||||
<< kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType << " or "
|
||||
<< kDSEncryptType << ", DP scheme is used by default.";
|
||||
encrypt_type_ = kDPEncryptType;
|
||||
} else {
|
||||
encrypt_type_ = encrypt_type;
|
||||
}
|
||||
encrypt_type_ = encrypt_type;
|
||||
}
|
||||
|
||||
const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
|
||||
|
@ -234,8 +236,9 @@ void PSContext::set_dp_eps(float dp_eps) {
|
|||
if (dp_eps > 0) {
|
||||
dp_eps_ = dp_eps;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
|
||||
return;
|
||||
MS_LOG(WARNING) << dp_eps << " is invalid, dp_eps must be larger than 0, 50 is used by default.";
|
||||
float dp_eps_default = 50;
|
||||
dp_eps_ = dp_eps_default;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -245,8 +248,9 @@ void PSContext::set_dp_delta(float dp_delta) {
|
|||
if (dp_delta > 0 && dp_delta < 1) {
|
||||
dp_delta_ = dp_delta;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
|
||||
return;
|
||||
MS_LOG(WARNING) << dp_delta << " is invalid, dp_delta must be in range of (0, 1), 0.01 is used by default.";
|
||||
float dp_delta_default = 0.01;
|
||||
dp_delta_ = dp_delta_default;
|
||||
}
|
||||
}
|
||||
float PSContext::dp_delta() const { return dp_delta_; }
|
||||
|
@ -255,12 +259,73 @@ void PSContext::set_dp_norm_clip(float dp_norm_clip) {
|
|||
if (dp_norm_clip > 0) {
|
||||
dp_norm_clip_ = dp_norm_clip;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
|
||||
return;
|
||||
MS_LOG(WARNING) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0, 1 is used by default.";
|
||||
float dp_norm_clip_default = 1;
|
||||
dp_norm_clip_ = dp_norm_clip_default;
|
||||
}
|
||||
}
|
||||
float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
|
||||
|
||||
void PSContext::set_sign_k(float sign_k) {
|
||||
float sign_k_upper = 0.25;
|
||||
if (sign_k > 0 && sign_k <= sign_k_upper) {
|
||||
sign_k_ = sign_k;
|
||||
} else {
|
||||
MS_LOG(WARNING) << sign_k << " is invalid, sign_k must be in range of (0, 0.25], 0.01 is used by default.";
|
||||
float sign_k_default = 0.01;
|
||||
sign_k_ = sign_k_default;
|
||||
}
|
||||
}
|
||||
float PSContext::sign_k() const { return sign_k_; }
|
||||
|
||||
void PSContext::set_sign_eps(float sign_eps) {
|
||||
float sign_eps_upper = 100;
|
||||
if (sign_eps > 0 && sign_eps <= sign_eps_upper) {
|
||||
sign_eps_ = sign_eps;
|
||||
} else {
|
||||
MS_LOG(WARNING) << sign_eps << " is invalid, sign_eps must be in range of (0, 100], 100 is used by default.";
|
||||
float sign_eps_default = 100;
|
||||
sign_eps_ = sign_eps_default;
|
||||
}
|
||||
}
|
||||
float PSContext::sign_eps() const { return sign_eps_; }
|
||||
|
||||
void PSContext::set_sign_thr_ratio(float sign_thr_ratio) {
|
||||
float sign_thr_ratio_bound = 0.5;
|
||||
if (sign_thr_ratio >= sign_thr_ratio_bound && sign_thr_ratio <= 1) {
|
||||
sign_thr_ratio_ = sign_thr_ratio;
|
||||
} else {
|
||||
MS_LOG(WARNING) << sign_thr_ratio
|
||||
<< " is invalid, sign_thr_ratio must be in range of [0.5, 1], 0.6 is used by default.";
|
||||
float sign_thr_ratio_default = 0.6;
|
||||
sign_thr_ratio_ = sign_thr_ratio_default;
|
||||
}
|
||||
}
|
||||
float PSContext::sign_thr_ratio() const { return sign_thr_ratio_; }
|
||||
|
||||
void PSContext::set_sign_global_lr(float sign_global_lr) {
|
||||
if (sign_global_lr > 0) {
|
||||
sign_global_lr_ = sign_global_lr;
|
||||
} else {
|
||||
MS_LOG(WARNING) << sign_global_lr << " is invalid, sign_global_lr must be larger than 0, 1 is used by default.";
|
||||
float sign_global_lr_default = 1;
|
||||
sign_global_lr_ = sign_global_lr_default;
|
||||
}
|
||||
}
|
||||
float PSContext::sign_global_lr() const { return sign_global_lr_; }
|
||||
|
||||
void PSContext::set_sign_dim_out(int sign_dim_out) {
|
||||
int sign_dim_out_upper = 50;
|
||||
if (sign_dim_out >= 0 && sign_dim_out <= sign_dim_out_upper) {
|
||||
sign_dim_out_ = sign_dim_out;
|
||||
} else {
|
||||
MS_LOG(WARNING) << sign_dim_out << " is invalid, sign_dim_out must be in range of [0, 50], 0 is used by default.";
|
||||
float sign_dim_out_default = 0;
|
||||
sign_dim_out_ = sign_dim_out_default;
|
||||
}
|
||||
}
|
||||
int PSContext::sign_dim_out() const { return sign_dim_out_; }
|
||||
|
||||
void PSContext::set_ms_role(const std::string &role) {
|
||||
if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
|
||||
MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
|
||||
|
|
|
@ -39,6 +39,7 @@ constexpr char kDPEncryptType[] = "DP_ENCRYPT";
|
|||
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
|
||||
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
|
||||
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
|
||||
constexpr char kDSEncryptType[] = "SIGNDS";
|
||||
|
||||
// Use binary data to represent federated learning server's context so that we can judge which round resets the
|
||||
// iteration. From right to left, each bit stands for:
|
||||
|
@ -173,6 +174,21 @@ class PSContext {
|
|||
void set_dp_norm_clip(float dp_norm_clip);
|
||||
float dp_norm_clip() const;
|
||||
|
||||
void set_sign_k(float sign_k);
|
||||
float sign_k() const;
|
||||
|
||||
void set_sign_eps(float sign_eps);
|
||||
float sign_eps() const;
|
||||
|
||||
void set_sign_thr_ratio(float sign_thr_ratio);
|
||||
float sign_thr_ratio() const;
|
||||
|
||||
void set_sign_global_lr(float sign_global_lr);
|
||||
float sign_global_lr() const;
|
||||
|
||||
void set_sign_dim_out(int sign_dim_out);
|
||||
int sign_dim_out() const;
|
||||
|
||||
core::ClusterConfig &cluster_config();
|
||||
|
||||
void set_root_first_ca_path(const std::string &root_first_ca_path);
|
||||
|
@ -254,6 +270,11 @@ class PSContext {
|
|||
dp_delta_(0.01),
|
||||
dp_norm_clip_(1.0),
|
||||
encrypt_type_(kNotEncryptType),
|
||||
sign_k_(0.01),
|
||||
sign_eps_(100),
|
||||
sign_thr_ratio_(0.6),
|
||||
sign_global_lr_(1),
|
||||
sign_dim_out_(0),
|
||||
enable_ssl_(false),
|
||||
client_password_(""),
|
||||
server_password_(""),
|
||||
|
@ -364,6 +385,21 @@ class PSContext {
|
|||
// Secure mechanism for federated learning. Used in federated learning for now.
|
||||
std::string encrypt_type_;
|
||||
|
||||
// Top-k of SignDS mechanism.
|
||||
float sign_k_;
|
||||
|
||||
// Privacy budget epsilon of SignDS mechanism.
|
||||
float sign_eps_;
|
||||
|
||||
// The threshold for the expected ratio of topk dimensions in the output of SignDS mechanism.
|
||||
float sign_thr_ratio_;
|
||||
|
||||
// Global learning rate of SignDS mechanism.
|
||||
float sign_global_lr_;
|
||||
|
||||
// The number of output dimension of SignDS mechanism.
|
||||
int sign_dim_out_;
|
||||
|
||||
// Whether to enable ssl for network communication.
|
||||
bool enable_ssl_;
|
||||
// Password used to decode p12 file.
|
||||
|
|
|
@ -24,5 +24,6 @@ package com.mindspore.flclient;
|
|||
public enum EncryptLevel {
|
||||
PW_ENCRYPT,
|
||||
DP_ENCRYPT,
|
||||
SIGNDS,
|
||||
NOT_ENCRYPT
|
||||
}
|
|
@ -79,6 +79,11 @@ public class FLLiteClient {
|
|||
private String nextRequestTime;
|
||||
private Client client;
|
||||
private Map<String, float[]> oldFeatureMap;
|
||||
private float signK = 0.01f;
|
||||
private float signEps = 100;
|
||||
private float signThrRatio = 0.6f;
|
||||
private float signGlobalLr = 1f;
|
||||
private int signDimOut = 0;
|
||||
|
||||
/**
|
||||
* Defining a constructor of teh class FLLiteClient.
|
||||
|
@ -124,11 +129,11 @@ public class FLLiteClient {
|
|||
}
|
||||
switch (localFLParameter.getEncryptLevel()) {
|
||||
case PW_ENCRYPT:
|
||||
minSecretNum = cipherPublicParams.t();
|
||||
int primeLength = cipherPublicParams.primeLength();
|
||||
minSecretNum = cipherPublicParams.pwParams().t();
|
||||
int primeLength = cipherPublicParams.pwParams().primeLength();
|
||||
prime = new byte[primeLength];
|
||||
for (int i = 0; i < primeLength; i++) {
|
||||
prime[i] = (byte) cipherPublicParams.prime(i);
|
||||
prime[i] = (byte) cipherPublicParams.pwParams().prime(i);
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
||||
if (minSecretNum <= 0) {
|
||||
|
@ -138,14 +143,26 @@ public class FLLiteClient {
|
|||
}
|
||||
break;
|
||||
case DP_ENCRYPT:
|
||||
dpEps = cipherPublicParams.dpEps();
|
||||
dpDelta = cipherPublicParams.dpDelta();
|
||||
dpNormClipFactor = cipherPublicParams.dpNormClip();
|
||||
dpEps = cipherPublicParams.dpParams().dpEps();
|
||||
dpDelta = cipherPublicParams.dpParams().dpDelta();
|
||||
dpNormClipFactor = cipherPublicParams.dpParams().dpNormClip();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " +
|
||||
dpNormClipFactor));
|
||||
break;
|
||||
case SIGNDS:
|
||||
signK = cipherPublicParams.dsParams().signK();
|
||||
signEps = cipherPublicParams.dsParams().signEps();
|
||||
signThrRatio = cipherPublicParams.dsParams().signThrRatio();
|
||||
signGlobalLr = cipherPublicParams.dsParams().signGlobalLr();
|
||||
signDimOut = cipherPublicParams.dsParams().signDimOut();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signK> from server: " + signK));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signEps> from server: " + signEps));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signThrRatio> from server: " + signThrRatio));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signGlobalLr> from server: " + signGlobalLr));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <SignDimOut> from server: " + signDimOut));
|
||||
break;
|
||||
default:
|
||||
LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt"));
|
||||
}
|
||||
|
@ -418,7 +435,7 @@ public class FLLiteClient {
|
|||
}
|
||||
LOGGER.info(Common.addTag("[getModel] get response from server ok!"));
|
||||
} catch (IOException e) {
|
||||
failed("[getModel] un sloved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError);
|
||||
failed("[getModel] unsolved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
@ -511,12 +528,24 @@ public class FLLiteClient {
|
|||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
|
||||
LOGGER.severe(Common.addTag("---Differential privacy init failed---"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case SIGNDS:
|
||||
// get the feature map before train
|
||||
oldFeatureMap = getFeatureMap();
|
||||
curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
LOGGER.severe(Common.addTag("---SignDS init failed---"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[Encrypt] set parameters for SignDS!"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case NOT_ENCRYPT:
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("[Encrypt] don't mask model"));
|
||||
|
@ -552,6 +581,10 @@ public class FLLiteClient {
|
|||
LOGGER.info(Common.addTag("[Encrypt] haven't mask model"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
return FLClientStatus.SUCCESS;
|
||||
case SIGNDS:
|
||||
LOGGER.info(Common.addTag("[Encrypt] SIGNDS do not need unmasking"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
return FLClientStatus.SUCCESS;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
|
|
|
@ -163,9 +163,10 @@ public class LocalFLParameter {
|
|||
}
|
||||
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.SIGNDS.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
|
||||
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before setting"));
|
||||
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT or SIGNDS, please check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.encryptLevel = encryptLevel;
|
||||
|
|
|
@ -23,6 +23,8 @@ import mindspore.schema.FeatureMap;
|
|||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Map;
|
||||
import java.util.List;
|
||||
import java.util.HashMap;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
|
@ -46,6 +48,12 @@ public class SecureProtocol {
|
|||
private double dpNormClip;
|
||||
private ArrayList<String> updateFeatureName = new ArrayList<String>();
|
||||
private int retCode;
|
||||
private float signK;
|
||||
private float signEps;
|
||||
private float signThrRatio;
|
||||
private float signGlobalLr;
|
||||
private int signDimOut;
|
||||
|
||||
|
||||
/**
|
||||
* Obtain current status code in client.
|
||||
|
@ -102,6 +110,22 @@ public class SecureProtocol {
|
|||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setting parameters for dimension select.
|
||||
*
|
||||
* @param map model weights.
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus setDSParameter(float signK, float signEps, float signThrRatio, float signGlobalLr, int signDimOut, Map<String, float[]> map) {
|
||||
this.signK = signK;
|
||||
this.signEps = signEps;
|
||||
this.signThrRatio = signThrRatio;
|
||||
this.signGlobalLr = signGlobalLr;
|
||||
this.signDimOut = signDimOut;
|
||||
this.modelMap = map;
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the feature names that needed to be encrypted.
|
||||
*
|
||||
|
@ -367,6 +391,10 @@ public class SecureProtocol {
|
|||
}
|
||||
}
|
||||
updateL2Norm = Math.sqrt(updateL2Norm);
|
||||
if (updateL2Norm == 0) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check"));
|
||||
return new int[0];
|
||||
}
|
||||
double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
|
||||
|
||||
// clip and add noise
|
||||
|
@ -412,4 +440,282 @@ public class SecureProtocol {
|
|||
}
|
||||
return featuresMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of combinations of n things taken k.
|
||||
*
|
||||
* @param n Number of things.
|
||||
* @param k Number of elements taken.
|
||||
* @return the total number of "n choose k" combinations.
|
||||
*/
|
||||
private static double comb(double n, double k) {
|
||||
boolean cond = (k <= n) && (n >= 0) && (k >= 0);
|
||||
double m = n + 1;
|
||||
if (!cond) {
|
||||
return 0;
|
||||
} else {
|
||||
double nTerm = Math.min(k, n - k);
|
||||
double res = 1;
|
||||
for (int i = 1; i <= nTerm; i++) {
|
||||
res *= (m - i);
|
||||
res /= i;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the number of possible combinations of output set given the number of topk dimensions.
|
||||
* c(k, v) * c(d-k, h-v)
|
||||
*
|
||||
* @param numInter the number of dimensions from topk set.
|
||||
* @param topkDim the size of top-k set.
|
||||
* @param inputDim total number of dimensions in the model.
|
||||
* @param outputDim the number of dimensions selected for constructing sparse local updates.
|
||||
* @return the number of possible combinations of output set.
|
||||
*/
|
||||
private static double countCombs(int numInter, int topkDim, int inputDim, int outputDim) {
|
||||
return comb(topkDim, numInter) * comb(inputDim - topkDim, outputDim - numInter);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the probability mass function of the number of topk dimensions in the output set.
|
||||
* v is the number of dimensions from topk set.
|
||||
*
|
||||
* @param thr threshold of the number of topk dimensions in the output set.
|
||||
* @param topkDim the size of top-k set.
|
||||
* @param inputDim total number of dimensions in the model.
|
||||
* @param outputDim the number of dimensions selected for constructing sparse local updates.
|
||||
* @param eps the privacy budget of SignDS alg.
|
||||
* @return the probability mass function.
|
||||
*/
|
||||
private static List<Double> calcPmf(int thr, int topkDim, int inputDim, int outputDim, float eps) {
|
||||
List<Double> pmf = new ArrayList<>();
|
||||
double newPmf;
|
||||
for (int v = 0; v <= outputDim; v++) {
|
||||
if (v < thr) {
|
||||
newPmf = countCombs(v, topkDim, inputDim, outputDim);
|
||||
} else {
|
||||
newPmf = countCombs(v, topkDim, inputDim, outputDim) * Math.exp(eps);
|
||||
}
|
||||
pmf.add(newPmf);
|
||||
}
|
||||
double pmfSum = 0;
|
||||
for (int i = 0; i < pmf.size(); i++) {
|
||||
pmfSum += pmf.get(i);
|
||||
}
|
||||
if (pmfSum == 0) {
|
||||
LOGGER.severe(Common.addTag("[SignDS] probability mass function is 0, please check"));
|
||||
return new ArrayList<>();
|
||||
}
|
||||
for (int i = 0; i < pmf.size(); i++) {
|
||||
pmf.set(i, pmf.get(i) / pmfSum);
|
||||
}
|
||||
return pmf;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the expected number of topk dimensions in the output set given outputDim.
|
||||
* The size of pmf is also outputDim.
|
||||
*
|
||||
* @param pmf probability mass function
|
||||
* @return the expectation of the topk dimensions in the output set.
|
||||
*/
|
||||
private static double calcExpectation(List<Double> pmf) {
|
||||
double sumExpectation = 0;
|
||||
for (int i = 0; i < pmf.size(); i++) {
|
||||
sumExpectation += (i * pmf.get(i));
|
||||
}
|
||||
return sumExpectation;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the optimum threshold for the number of topk dimension in the output set.
|
||||
* The optimum threshold is an integer among [1, outputDim], which has the largest
|
||||
* expectation value.
|
||||
*
|
||||
* @param topkDim the size of top-k set.
|
||||
* @param inputDim total number of dimensions in the model.
|
||||
* @param outputDim the number of dimensions selected for constructing sparse local updates.
|
||||
* @param eps the privacy budget of SignDS alg.
|
||||
* @return the optimum threshold.
|
||||
*/
|
||||
private static int calcOptThr(int topkDim, int inputDim, int outputDim, float eps) {
|
||||
double optExpect = 0;
|
||||
double optT = 0;
|
||||
for (int t = 1; t <= outputDim; t++) {
|
||||
double newExpect = calcExpectation(calcPmf(t, topkDim, inputDim, outputDim, eps));
|
||||
if (newExpect > optExpect) {
|
||||
optExpect = newExpect;
|
||||
optT = t;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return (int) Math.max(optT, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tool function for finding the optimum output dimension.
|
||||
* The main idea is to iteratively search for the largest output dimension while
|
||||
* ensuring the expected ratio of topk dimensions in the output set larger than
|
||||
* the target ratio.
|
||||
*
|
||||
* @param thrInterRatio threshold of the expected ratio of topk dimensions
|
||||
* @param topkDim the size of top-k set.
|
||||
* @param inputDim total number of dimensions in the model.
|
||||
* @param eps the privacy budget of SignDS alg.
|
||||
* @return the optimum output dimension.
|
||||
*/
|
||||
private static int findOptOutputDim(float thrInterRatio, int topkDim, int inputDim, float eps) {
|
||||
int outputDim = 1;
|
||||
while (true) {
|
||||
int thr = calcOptThr(topkDim, inputDim, outputDim, eps);
|
||||
double expectedRatio = calcExpectation(calcPmf(thr, topkDim, inputDim, outputDim, eps)) / outputDim;
|
||||
if (expectedRatio < thrInterRatio || Double.isNaN(expectedRatio)) {
|
||||
break;
|
||||
} else {
|
||||
outputDim += 1;
|
||||
}
|
||||
}
|
||||
return Math.max(1, (outputDim - 1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine the number of dimensions to be sampled from the topk dimension set via
|
||||
* inverse sampling.
|
||||
* The main steps of the trick of inverse sampling include:
|
||||
* 1. Sample a random probability from the uniform distribution U(0, 1).
|
||||
* 2. Calculate the cumulative distribution of numInter, namely the number of
|
||||
* topk dimensions in the output set.
|
||||
* 3. Compare the cumulative distribution with the random probability and determine
|
||||
* the value of numInter.
|
||||
*
|
||||
* @param thrDim threshold of the number of topk dimensions in the output set.
|
||||
* @param denominator calculate denominator given the threshold.
|
||||
* @param topkDim the size of top-k set.
|
||||
* @param inputDim total number of dimensions in the model.
|
||||
* @param outputDim the number of dimensions selected for constructing sparse local updates.
|
||||
* @param eps the privacy budget of SignDS alg.
|
||||
* @return the number of dimensions to be sampled from the top-k dimension set.
|
||||
*/
|
||||
private static int countInters(int thrDim, double denominator, int topkDim, int inputDim, int outputDim, float eps) {
|
||||
SecureRandom secureRandom = new SecureRandom();
|
||||
double randomProb = secureRandom.nextDouble();
|
||||
int numInter = 0;
|
||||
double prob = countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
|
||||
while (prob < randomProb) {
|
||||
numInter += 1;
|
||||
if (numInter < thrDim) {
|
||||
prob += countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
|
||||
} else {
|
||||
prob += Math.exp(eps) * countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
|
||||
}
|
||||
}
|
||||
return numInter;
|
||||
}
|
||||
|
||||
/**
|
||||
* SignDS model weights.
|
||||
*
|
||||
* @param builder the FlatBufferBuilder object used for serialization model weights.
|
||||
* @param trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
|
||||
public int[] signDSModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
|
||||
Map<String, float[]> mapBeforeTrain = modelMap;
|
||||
int layerNum = updateFeatureName.size();
|
||||
int[] featuresMap = new int[layerNum];
|
||||
SecureRandom secureRandom = Common.getSecureRandom();
|
||||
boolean sign = secureRandom.nextBoolean();
|
||||
List<String> nonTopkKeyList = new ArrayList<>();
|
||||
List<String> topkKeyList = new ArrayList<>();
|
||||
Map<String, Float> allUpdateMap = new HashMap<>();
|
||||
for (int i = 0; i < layerNum; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] dataAfterTrain = trainedMap.get(key);
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
for (int j = 0; j < dataAfterTrain.length; j++) {
|
||||
float updateData = dataAfterTrain[j] - dataBeforeTrain[j];
|
||||
String ij = Integer.toString(i) + ',' + j;
|
||||
allUpdateMap.put(ij, updateData);
|
||||
}
|
||||
}
|
||||
int inputDim = allUpdateMap.size();
|
||||
int topkDim = (int) (signK * inputDim);
|
||||
if (signDimOut == 0) {
|
||||
signDimOut = findOptOutputDim(signThrRatio, topkDim, inputDim, signEps);
|
||||
}
|
||||
int thrDim = calcOptThr(topkDim, inputDim, signDimOut, signEps);
|
||||
double combLessInter = 0d;
|
||||
double combMoreInter = 0d;
|
||||
for (int i = 0; i < thrDim; i++) {
|
||||
combLessInter += countCombs(i, topkDim, inputDim, signDimOut);
|
||||
}
|
||||
for (int i = thrDim; i <= signDimOut; i++) {
|
||||
combMoreInter += countCombs(i, topkDim, inputDim, signDimOut);
|
||||
}
|
||||
double denominator = combLessInter + Math.exp(signEps) * combMoreInter;
|
||||
if (denominator == 0) {
|
||||
LOGGER.severe(Common.addTag("[SignDS] denominator is 0, please check"));
|
||||
return new int[0];
|
||||
}
|
||||
int numInter = countInters(thrDim, denominator, topkDim, inputDim, signDimOut, signEps);
|
||||
int numOuter = signDimOut - numInter;
|
||||
if (topkDim < numInter || signDimOut <= 0) {
|
||||
LOGGER.severe("[SignDS] topkDim or signDimOut is ERROR! please check");
|
||||
return new int[0];
|
||||
}
|
||||
|
||||
List<Map.Entry<String, Float>> allUpdateList = new ArrayList<>(allUpdateMap.entrySet());
|
||||
if (sign) {
|
||||
allUpdateList.sort((o1, o2) -> Float.compare(o2.getValue(), o1.getValue()));
|
||||
} else {
|
||||
allUpdateList.sort((o1, o2) -> Float.compare(o1.getValue(), o2.getValue()));
|
||||
}
|
||||
for (int i = 0; i < topkDim; i++) {
|
||||
topkKeyList.add(allUpdateList.get(i).getKey());
|
||||
}
|
||||
for (int i = topkDim; i < allUpdateList.size(); i++) {
|
||||
nonTopkKeyList.add(allUpdateList.get(i).getKey());
|
||||
}
|
||||
List<String> outputDimensionIJStringList = new ArrayList<>();
|
||||
for (int i = topkKeyList.size(); i > topkKeyList.size() - numInter; i--) {
|
||||
int randomIndex = secureRandom.nextInt(i);
|
||||
String randomChoiceTopkIJString = topkKeyList.get(randomIndex);
|
||||
topkKeyList.set(randomIndex, topkKeyList.get(i - 1));
|
||||
topkKeyList.set(i - 1, randomChoiceTopkIJString);
|
||||
outputDimensionIJStringList.add(randomChoiceTopkIJString);
|
||||
}
|
||||
for (int i = nonTopkKeyList.size(); i > nonTopkKeyList.size() - numOuter; i--) {
|
||||
int randomIndex = secureRandom.nextInt(i);
|
||||
String randomChoiceNonTopkIJString = nonTopkKeyList.get(randomIndex);
|
||||
nonTopkKeyList.set(randomIndex, nonTopkKeyList.get(i - 1));
|
||||
nonTopkKeyList.set(i - 1, randomChoiceNonTopkIJString);
|
||||
outputDimensionIJStringList.add(randomChoiceNonTopkIJString);
|
||||
}
|
||||
float signValue = sign ? 1f * signGlobalLr : -1f * signGlobalLr;
|
||||
for (String ijString : outputDimensionIJStringList) {
|
||||
String[] ij = ijString.split(",");
|
||||
int iKeyIndex = Integer.parseInt(ij[0]);
|
||||
int jDataIndex = Integer.parseInt(ij[1]);
|
||||
String key = updateFeatureName.get(iKeyIndex);
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
dataBeforeTrain[jDataIndex] += signValue;
|
||||
mapBeforeTrain.put(key, dataBeforeTrain);
|
||||
}
|
||||
for (int i = 0; i < layerNum; i++) {
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
for (int j = 0; j < dataBeforeTrain.length; j++) {
|
||||
dataBeforeTrain[j] *= trainDataSize;
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, dataBeforeTrain);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
featuresMap[i] = featureMap;
|
||||
}
|
||||
return featuresMap;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -310,6 +310,18 @@ public class UpdateModel {
|
|||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||
return this;
|
||||
case SIGNDS:
|
||||
int[] fmOffsetsSignDS = secureProtocol.signDSModel(builder, trainDataSize, trainedMap);
|
||||
if (fmOffsetsSignDS == null || fmOffsetsSignDS.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsSignDS from <secureProtocol.signDSModel> is " +
|
||||
"null, please check");
|
||||
retCode = ResponseCode.RequestError;
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignDS);
|
||||
LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!"));
|
||||
return this;
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
int featureSize = updateFeatureName.size();
|
||||
|
|
|
@ -1059,11 +1059,18 @@ def set_fl_context(**kwargs):
|
|||
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
|
||||
suggested to be 0.5~2. Default: 1.0.
|
||||
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT',
|
||||
'PW_ENCRYPT' or 'STABLE_PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied
|
||||
'PW_ENCRYPT', 'STABLE_PW_ENCRYPT' or 'SIGNDS'. If 'DP_ENCRYPT', differential privacy schema would be applied
|
||||
for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip
|
||||
as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients'
|
||||
model from stealing in cross-device scenario. If 'STABLE_PW_ENCRYPT', pairwise secure aggregation would
|
||||
be applied to protect clients' model from stealing in cross-silo scenario. Default: 'NOT_ENCRYPT'.
|
||||
be applied to protect clients' model from stealing in cross-silo scenario. If 'SIGNDS', SignDS schema would
|
||||
be applied for clients. Default: 'NOT_ENCRYPT'.
|
||||
sign_k (float): SignDS: Top-k ratio, namely the number of top-k dimensions divided by the total number of
|
||||
dimensions. Default: 0.01.
|
||||
sign_eps (float): SignDS: Privacy budget. Default: 100.
|
||||
sign_thr_ratio (float): SignDS: Threshold of the expected topk dimension. Default: 0.6.
|
||||
sign_global_lr (float): SignDS: The constant value assigned to the selected dimension. Default: 1.
|
||||
sign_dim_out (int): SignDS: Number of output dimensions. Default: 0.
|
||||
config_file_path (string): Configuration file path used by recovery. Default: ''.
|
||||
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
|
||||
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: False.
|
||||
|
|
|
@ -71,7 +71,12 @@ _set_ps_context_func_map = {
|
|||
"dp_delta": ps_context().set_dp_delta,
|
||||
"dp_norm_clip": ps_context().set_dp_norm_clip,
|
||||
"encrypt_type": ps_context().set_encrypt_type,
|
||||
"http_url_prefix": ps_context().set_http_url_prefix
|
||||
"http_url_prefix": ps_context().set_http_url_prefix,
|
||||
"sign_k": ps_context().set_sign_k,
|
||||
"sign_eps": ps_context().set_sign_eps,
|
||||
"sign_thr_ratio": ps_context().set_sign_thr_ratio,
|
||||
"sign_global_lr": ps_context().set_sign_global_lr,
|
||||
"sign_dim_out": ps_context().set_sign_dim_out
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
|
@ -112,7 +117,12 @@ _get_ps_context_func_map = {
|
|||
"server_password": ps_context().server_password,
|
||||
"scheduler_manage_port": ps_context().scheduler_manage_port,
|
||||
"config_file_path": ps_context().config_file_path,
|
||||
"http_url_prefix": ps_context().http_url_prefix
|
||||
"http_url_prefix": ps_context().http_url_prefix,
|
||||
"sign_k": ps_context().sign_k,
|
||||
"sign_eps": ps_context().sign_eps,
|
||||
"sign_thr_ratio": ps_context().sign_thr_ratio,
|
||||
"sign_global_lr": ps_context().sign_global_lr,
|
||||
"sign_dim_out": ps_context().sign_dim_out
|
||||
}
|
||||
|
||||
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
|
||||
|
|
|
@ -16,15 +16,32 @@
|
|||
|
||||
namespace mindspore.schema;
|
||||
|
||||
table CipherPublicParams {
|
||||
table PWParams {
|
||||
t:int;
|
||||
p:[ubyte];
|
||||
g:int;
|
||||
prime:[ubyte];
|
||||
}
|
||||
|
||||
table DPParams {
|
||||
dp_eps:float;
|
||||
dp_delta:float;
|
||||
dp_norm_clip:float;
|
||||
}
|
||||
|
||||
table DSParams {
|
||||
sign_k:float;
|
||||
sign_eps:float;
|
||||
sign_thr_ratio:float;
|
||||
sign_global_lr:float;
|
||||
sign_dim_out:int;
|
||||
}
|
||||
|
||||
table CipherPublicParams {
|
||||
encrypt_type:string;
|
||||
pw_params:PWParams;
|
||||
dp_params:DPParams;
|
||||
ds_params:DSParams;
|
||||
}
|
||||
|
||||
table ClientPublicKeys {
|
||||
|
|
|
@ -76,6 +76,12 @@ def parse_args():
|
|||
parser.add_argument("--root_second_ca_path", type=str, default="")
|
||||
parser.add_argument("--equip_crl_path", type=str, default="")
|
||||
parser.add_argument("--replay_attack_time_diff", type=int, default=600000)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -118,6 +124,11 @@ def server_train(args):
|
|||
root_second_ca_path = args.root_second_ca_path
|
||||
equip_crl_path = args.equip_crl_path
|
||||
replay_attack_time_diff = args.replay_attack_time_diff
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
# Replace some parameters with federated learning parameters.
|
||||
train_cfg.max_global_epoch = fl_iteration_num
|
||||
|
@ -155,7 +166,12 @@ def server_train(args):
|
|||
"root_first_ca_path": root_first_ca_path,
|
||||
"root_second_ca_path": root_second_ca_path,
|
||||
"equip_crl_path": equip_crl_path,
|
||||
"replay_attack_time_diff": replay_attack_time_diff
|
||||
"replay_attack_time_diff": replay_attack_time_diff,
|
||||
"sign_k": sign_k,
|
||||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
}
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
|
|
|
@ -58,6 +58,12 @@ parser.add_argument("--root_first_ca_path", type=str, default="")
|
|||
parser.add_argument("--root_second_ca_path", type=str, default="")
|
||||
parser.add_argument("--equip_crl_path", type=str, default="")
|
||||
parser.add_argument("--replay_attack_time_diff", type=int, default=600000)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -93,6 +99,11 @@ root_first_ca_path = args.root_first_ca_path
|
|||
root_second_ca_path = args.root_second_ca_path
|
||||
equip_crl_path = args.equip_crl_path
|
||||
replay_attack_time_diff = args.replay_attack_time_diff
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -139,6 +150,11 @@ for i in range(local_server_num):
|
|||
cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
|
||||
cmd_server += " --equip_crl_path=" + str(equip_crl_path)
|
||||
cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff)
|
||||
cmd_server += " --sign_k=" + str(sign_k)
|
||||
cmd_server += " --sign_eps=" + str(sign_eps)
|
||||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -49,6 +49,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -80,6 +86,11 @@ if __name__ == "__main__":
|
|||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -121,6 +132,11 @@ if __name__ == "__main__":
|
|||
cmd_server += " --server_password=" + str(server_password)
|
||||
cmd_server += " --enable_ssl=" + str(enable_ssl)
|
||||
cmd_server += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_server += " --sign_k=" + str(sign_k)
|
||||
cmd_server += " --sign_eps=" + str(sign_eps)
|
||||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -56,6 +56,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -87,6 +93,11 @@ encrypt_type = args.encrypt_type
|
|||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -117,7 +128,12 @@ ctx = {
|
|||
"encrypt_type": encrypt_type,
|
||||
"client_password": client_password,
|
||||
"server_password": server_password,
|
||||
"enable_ssl": enable_ssl
|
||||
"enable_ssl": enable_ssl,
|
||||
"sign_k": sign_k,
|
||||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
|
|
@ -57,6 +57,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -92,6 +98,11 @@ root_first_ca_path = args.root_first_ca_path
|
|||
root_second_ca_path = args.root_second_ca_path
|
||||
equip_crl_path = args.equip_crl_path
|
||||
replay_attack_time_diff = args.replay_attack_time_diff
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -138,6 +149,11 @@ for i in range(local_server_num):
|
|||
cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
|
||||
cmd_server += " --equip_crl_path=" + str(equip_crl_path)
|
||||
cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff)
|
||||
cmd_server += " --sign_k=" + str(sign_k)
|
||||
cmd_server += " --sign_eps=" + str(sign_eps)
|
||||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -65,6 +65,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -102,6 +108,11 @@ root_first_ca_path = args.root_first_ca_path
|
|||
root_second_ca_path = args.root_second_ca_path
|
||||
equip_crl_path = args.equip_crl_path
|
||||
replay_attack_time_diff = args.replay_attack_time_diff
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -138,7 +149,12 @@ ctx = {
|
|||
"encrypt_type": encrypt_type,
|
||||
"client_password": client_password,
|
||||
"server_password": server_password,
|
||||
"enable_ssl": enable_ssl
|
||||
"enable_ssl": enable_ssl,
|
||||
"sign_k": sign_k,
|
||||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
|
|
@ -55,6 +55,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -91,6 +97,11 @@ if __name__ == "__main__":
|
|||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -137,6 +148,11 @@ if __name__ == "__main__":
|
|||
cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff)
|
||||
cmd_server += " --enable_ssl=" + str(enable_ssl)
|
||||
cmd_server += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_server += " --sign_k=" + str(sign_k)
|
||||
cmd_server += " --sign_eps=" + str(sign_eps)
|
||||
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
|
||||
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
|
||||
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -64,6 +64,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
|||
parser.add_argument("--client_password", type=str, default="")
|
||||
parser.add_argument("--server_password", type=str, default="")
|
||||
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
|
||||
# parameters for 'SIGNDS'
|
||||
parser.add_argument("--sign_k", type=float, default=0.01)
|
||||
parser.add_argument("--sign_eps", type=float, default=100)
|
||||
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
|
||||
parser.add_argument("--sign_global_lr", type=float, default=0.1)
|
||||
parser.add_argument("--sign_dim_out", type=int, default=0)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -100,6 +106,11 @@ replay_attack_time_diff = args.replay_attack_time_diff
|
|||
client_password = args.client_password
|
||||
server_password = args.server_password
|
||||
enable_ssl = args.enable_ssl
|
||||
sign_k = args.sign_k
|
||||
sign_eps = args.sign_eps
|
||||
sign_thr_ratio = args.sign_thr_ratio
|
||||
sign_global_lr = args.sign_global_lr
|
||||
sign_dim_out = args.sign_dim_out
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -135,7 +146,12 @@ ctx = {
|
|||
"encrypt_type": encrypt_type,
|
||||
"client_password": client_password,
|
||||
"server_password": server_password,
|
||||
"enable_ssl": enable_ssl
|
||||
"enable_ssl": enable_ssl,
|
||||
"sign_k": sign_k,
|
||||
"sign_eps": sign_eps,
|
||||
"sign_thr_ratio": sign_thr_ratio,
|
||||
"sign_global_lr": sign_global_lr,
|
||||
"sign_dim_out": sign_dim_out
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
|
Loading…
Reference in New Issue