This commit is contained in:
emmmmtang 2022-01-14 18:52:59 +08:00
parent c72a1146fb
commit 09c938a584
23 changed files with 692 additions and 37 deletions

View File

@ -55,11 +55,21 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
publicparam_.dp_delta = param.dp_delta;
publicparam_.dp_norm_clip = param.dp_norm_clip;
publicparam_.encrypt_type = param.encrypt_type;
publicparam_.sign_k = param.sign_k;
publicparam_.sign_eps = param.sign_eps;
publicparam_.sign_thr_ratio = param.sign_thr_ratio;
publicparam_.sign_global_lr = param.sign_global_lr;
publicparam_.sign_dim_out = param.sign_dim_out;
if (param.encrypt_type == mindspore::ps::kDPEncryptType) {
MS_LOG(INFO) << "DP parameters init, dp_eps: " << param.dp_eps;
MS_LOG(INFO) << "DP parameters init, dp_delta: " << param.dp_delta;
MS_LOG(INFO) << "DP parameters init, dp_norm_clip: " << param.dp_norm_clip;
MS_LOG(INFO) << "DP parameters init, dp_eps: " << param.dp_eps << ", dp_delta: " << param.dp_delta
<< ", dp_norm_clip: " << param.dp_norm_clip;
}
if (param.encrypt_type == mindspore::ps::kDSEncryptType) {
MS_LOG(INFO) << "Sign parameters init, sign_k: " << param.sign_k << ", sign_eps: " << param.sign_eps
<< ", sign_thr_ratio: " << param.sign_thr_ratio << ", sign_global_lr: " << param.sign_global_lr
<< ", sign_dim_out: " << param.sign_dim_out;
}
if (param.encrypt_type == mindspore::ps::kPWEncryptType) {

View File

@ -68,6 +68,11 @@ struct CipherPublicPara {
float dp_delta;
float dp_norm_clip;
string encrypt_type;
float sign_k;
float sign_eps;
float sign_thr_ratio;
float sign_global_lr;
int sign_dim_out;
};
class CipherMetaStorage {

View File

@ -330,9 +330,17 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
float dp_delta = param->dp_delta;
float dp_norm_clip = param->dp_norm_clip;
auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type());
float sign_k = param->sign_k;
float sign_eps = param->sign_eps;
float sign_thr_ratio = param->sign_thr_ratio;
float sign_global_lr = param->sign_global_lr;
int sign_dim_out = param->sign_dim_out;
auto pw_params = schema::CreatePWParams(*fbb.get(), t, p, g, prime);
auto dp_params = schema::CreateDPParams(*fbb.get(), dp_eps, dp_delta, dp_norm_clip);
auto ds_params = schema::CreateDSParams(*fbb.get(), sign_k, sign_eps, sign_thr_ratio, sign_global_lr, sign_dim_out);
auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
schema::CreateCipherPublicParams(*fbb.get(), encrypt_type, pw_params, dp_params, ds_params);
#endif
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));

View File

@ -256,6 +256,11 @@ void Server::InitCipher() {
float dp_delta = ps::PSContext::instance()->dp_delta();
float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
float sign_k = ps::PSContext::instance()->sign_k();
float sign_eps = ps::PSContext::instance()->sign_eps();
float sign_thr_ratio = ps::PSContext::instance()->sign_thr_ratio();
float sign_global_lr = ps::PSContext::instance()->sign_global_lr();
int sign_dim_out = ps::PSContext::instance()->sign_dim_out();
mindspore::armour::CipherPublicPara param;
param.g = cipher_g;
@ -268,6 +273,11 @@ void Server::InitCipher() {
param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = encrypt_type;
param.sign_k = sign_k;
param.sign_eps = sign_eps;
param.sign_thr_ratio = sign_thr_ratio;
param.sign_global_lr = sign_global_lr;
param.sign_dim_out = sign_dim_out;
BIGNUM *prim = BN_new();
if (prim == NULL) {

View File

@ -468,12 +468,18 @@ PYBIND11_MODULE(_c_expression, m) {
"Set configuration files required by the communication layer.")
.def("config_file_path", &PSContext::config_file_path,
"Get configuration files required by the communication layer.")
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
"Set dp norm clip for federated learning secure aggregation.")
.def("set_encrypt_type", &PSContext::set_encrypt_type,
"Set encrypt type for federated learning secure aggregation.")
.def("set_sign_k", &PSContext::set_sign_k, "Set sign k for federated learning SignDS.")
.def("sign_k", &PSContext::sign_k, "Get sign k for federated learning SignDS.")
.def("set_sign_eps", &PSContext::set_sign_eps, "Set sign eps for federated learning SignDS.")
.def("sign_eps", &PSContext::sign_eps, "Get sign eps for federated learning SignDS.")
.def("set_sign_thr_ratio", &PSContext::set_sign_thr_ratio, "Set sign thr ratio for federated learning SignDS.")
.def("sign_thr_ratio", &PSContext::sign_thr_ratio, "Get sign thr ratio for federated learning SignDS.")
.def("set_sign_global_lr", &PSContext::set_sign_global_lr, "Set sign global lr for federated learning SignDS.")
.def("sign_global_lr", &PSContext::sign_global_lr, "Get sign global lr for federated learning SignDS.")
.def("set_sign_dim_out", &PSContext::set_sign_dim_out, "Set sign dim out for federated learning SignDS.")
.def("sign_dim_out", &PSContext::sign_dim_out, "Get sign dim out for federated learning SignDS.")
.def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.")
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.");

View File

@ -220,12 +220,14 @@ const std::string &PSContext::server_mode() const { return server_mode_; }
void PSContext::set_encrypt_type(const std::string &encrypt_type) {
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType &&
encrypt_type != kStablePWEncryptType) {
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
<< kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType;
return;
encrypt_type != kStablePWEncryptType && encrypt_type != kDSEncryptType) {
MS_LOG(WARNING) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
<< kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType << " or "
<< kDSEncryptType << ", DP scheme is used by default.";
encrypt_type_ = kDPEncryptType;
} else {
encrypt_type_ = encrypt_type;
}
encrypt_type_ = encrypt_type;
}
const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
@ -234,8 +236,9 @@ void PSContext::set_dp_eps(float dp_eps) {
if (dp_eps > 0) {
dp_eps_ = dp_eps;
} else {
MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
return;
MS_LOG(WARNING) << dp_eps << " is invalid, dp_eps must be larger than 0, 50 is used by default.";
float dp_eps_default = 50;
dp_eps_ = dp_eps_default;
}
}
@ -245,8 +248,9 @@ void PSContext::set_dp_delta(float dp_delta) {
if (dp_delta > 0 && dp_delta < 1) {
dp_delta_ = dp_delta;
} else {
MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
return;
MS_LOG(WARNING) << dp_delta << " is invalid, dp_delta must be in range of (0, 1), 0.01 is used by default.";
float dp_delta_default = 0.01;
dp_delta_ = dp_delta_default;
}
}
float PSContext::dp_delta() const { return dp_delta_; }
@ -255,12 +259,73 @@ void PSContext::set_dp_norm_clip(float dp_norm_clip) {
if (dp_norm_clip > 0) {
dp_norm_clip_ = dp_norm_clip;
} else {
MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
return;
MS_LOG(WARNING) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0, 1 is used by default.";
float dp_norm_clip_default = 1;
dp_norm_clip_ = dp_norm_clip_default;
}
}
float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
void PSContext::set_sign_k(float sign_k) {
float sign_k_upper = 0.25;
if (sign_k > 0 && sign_k <= sign_k_upper) {
sign_k_ = sign_k;
} else {
MS_LOG(WARNING) << sign_k << " is invalid, sign_k must be in range of (0, 0.25], 0.01 is used by default.";
float sign_k_default = 0.01;
sign_k_ = sign_k_default;
}
}
float PSContext::sign_k() const { return sign_k_; }
void PSContext::set_sign_eps(float sign_eps) {
float sign_eps_upper = 100;
if (sign_eps > 0 && sign_eps <= sign_eps_upper) {
sign_eps_ = sign_eps;
} else {
MS_LOG(WARNING) << sign_eps << " is invalid, sign_eps must be in range of (0, 100], 100 is used by default.";
float sign_eps_default = 100;
sign_eps_ = sign_eps_default;
}
}
float PSContext::sign_eps() const { return sign_eps_; }
void PSContext::set_sign_thr_ratio(float sign_thr_ratio) {
float sign_thr_ratio_bound = 0.5;
if (sign_thr_ratio >= sign_thr_ratio_bound && sign_thr_ratio <= 1) {
sign_thr_ratio_ = sign_thr_ratio;
} else {
MS_LOG(WARNING) << sign_thr_ratio
<< " is invalid, sign_thr_ratio must be in range of [0.5, 1], 0.6 is used by default.";
float sign_thr_ratio_default = 0.6;
sign_thr_ratio_ = sign_thr_ratio_default;
}
}
float PSContext::sign_thr_ratio() const { return sign_thr_ratio_; }
void PSContext::set_sign_global_lr(float sign_global_lr) {
if (sign_global_lr > 0) {
sign_global_lr_ = sign_global_lr;
} else {
MS_LOG(WARNING) << sign_global_lr << " is invalid, sign_global_lr must be larger than 0, 1 is used by default.";
float sign_global_lr_default = 1;
sign_global_lr_ = sign_global_lr_default;
}
}
float PSContext::sign_global_lr() const { return sign_global_lr_; }
void PSContext::set_sign_dim_out(int sign_dim_out) {
int sign_dim_out_upper = 50;
if (sign_dim_out >= 0 && sign_dim_out <= sign_dim_out_upper) {
sign_dim_out_ = sign_dim_out;
} else {
MS_LOG(WARNING) << sign_dim_out << " is invalid, sign_dim_out must be in range of [0, 50], 0 is used by default.";
float sign_dim_out_default = 0;
sign_dim_out_ = sign_dim_out_default;
}
}
int PSContext::sign_dim_out() const { return sign_dim_out_; }
void PSContext::set_ms_role(const std::string &role) {
if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";

View File

@ -39,6 +39,7 @@ constexpr char kDPEncryptType[] = "DP_ENCRYPT";
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
constexpr char kDSEncryptType[] = "SIGNDS";
// Use binary data to represent federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
@ -173,6 +174,21 @@ class PSContext {
void set_dp_norm_clip(float dp_norm_clip);
float dp_norm_clip() const;
void set_sign_k(float sign_k);
float sign_k() const;
void set_sign_eps(float sign_eps);
float sign_eps() const;
void set_sign_thr_ratio(float sign_thr_ratio);
float sign_thr_ratio() const;
void set_sign_global_lr(float sign_global_lr);
float sign_global_lr() const;
void set_sign_dim_out(int sign_dim_out);
int sign_dim_out() const;
core::ClusterConfig &cluster_config();
void set_root_first_ca_path(const std::string &root_first_ca_path);
@ -254,6 +270,11 @@ class PSContext {
dp_delta_(0.01),
dp_norm_clip_(1.0),
encrypt_type_(kNotEncryptType),
sign_k_(0.01),
sign_eps_(100),
sign_thr_ratio_(0.6),
sign_global_lr_(1),
sign_dim_out_(0),
enable_ssl_(false),
client_password_(""),
server_password_(""),
@ -364,6 +385,21 @@ class PSContext {
// Secure mechanism for federated learning. Used in federated learning for now.
std::string encrypt_type_;
// Top-k of SignDS mechanism.
float sign_k_;
// Privacy budget epsilon of SignDS mechanism.
float sign_eps_;
// The threshold for the expected ratio of topk dimensions in the output of SignDS mechanism.
float sign_thr_ratio_;
// Global learning rate of SignDS mechanism.
float sign_global_lr_;
// The number of output dimension of SignDS mechanism.
int sign_dim_out_;
// Whether to enable ssl for network communication.
bool enable_ssl_;
// Password used to decode p12 file.

View File

@ -24,5 +24,6 @@ package com.mindspore.flclient;
public enum EncryptLevel {
PW_ENCRYPT,
DP_ENCRYPT,
SIGNDS,
NOT_ENCRYPT
}

View File

@ -79,6 +79,11 @@ public class FLLiteClient {
private String nextRequestTime;
private Client client;
private Map<String, float[]> oldFeatureMap;
private float signK = 0.01f;
private float signEps = 100;
private float signThrRatio = 0.6f;
private float signGlobalLr = 1f;
private int signDimOut = 0;
/**
* Defining a constructor of teh class FLLiteClient.
@ -124,11 +129,11 @@ public class FLLiteClient {
}
switch (localFLParameter.getEncryptLevel()) {
case PW_ENCRYPT:
minSecretNum = cipherPublicParams.t();
int primeLength = cipherPublicParams.primeLength();
minSecretNum = cipherPublicParams.pwParams().t();
int primeLength = cipherPublicParams.pwParams().primeLength();
prime = new byte[primeLength];
for (int i = 0; i < primeLength; i++) {
prime[i] = (byte) cipherPublicParams.prime(i);
prime[i] = (byte) cipherPublicParams.pwParams().prime(i);
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
if (minSecretNum <= 0) {
@ -138,14 +143,26 @@ public class FLLiteClient {
}
break;
case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps();
dpDelta = cipherPublicParams.dpDelta();
dpNormClipFactor = cipherPublicParams.dpNormClip();
dpEps = cipherPublicParams.dpParams().dpEps();
dpDelta = cipherPublicParams.dpParams().dpDelta();
dpNormClipFactor = cipherPublicParams.dpParams().dpNormClip();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " +
dpNormClipFactor));
break;
case SIGNDS:
signK = cipherPublicParams.dsParams().signK();
signEps = cipherPublicParams.dsParams().signEps();
signThrRatio = cipherPublicParams.dsParams().signThrRatio();
signGlobalLr = cipherPublicParams.dsParams().signGlobalLr();
signDimOut = cipherPublicParams.dsParams().signDimOut();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signK> from server: " + signK));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signEps> from server: " + signEps));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signThrRatio> from server: " + signThrRatio));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <signGlobalLr> from server: " + signGlobalLr));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <SignDimOut> from server: " + signDimOut));
break;
default:
LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt"));
}
@ -418,7 +435,7 @@ public class FLLiteClient {
}
LOGGER.info(Common.addTag("[getModel] get response from server ok!"));
} catch (IOException e) {
failed("[getModel] un sloved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError);
failed("[getModel] unsolved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError);
}
return status;
}
@ -511,12 +528,24 @@ public class FLLiteClient {
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
LOGGER.severe(Common.addTag("---Differential privacy init failed---"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
return FLClientStatus.SUCCESS;
case SIGNDS:
// get the feature map before train
oldFeatureMap = getFeatureMap();
curStatus = secureProtocol.setDSParameter(signK, signEps, signThrRatio, signGlobalLr, signDimOut, oldFeatureMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {
LOGGER.severe(Common.addTag("---SignDS init failed---"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[Encrypt] set parameters for SignDS!"));
return FLClientStatus.SUCCESS;
case NOT_ENCRYPT:
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("[Encrypt] don't mask model"));
@ -552,6 +581,10 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[Encrypt] haven't mask model"));
retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS;
case SIGNDS:
LOGGER.info(Common.addTag("[Encrypt] SIGNDS do not need unmasking"));
retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS;
default:
LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default"));
retCode = ResponseCode.SUCCEED;

View File

@ -163,9 +163,10 @@ public class LocalFLParameter {
}
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.SIGNDS.toString().equals(encryptLevel)) &&
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before setting"));
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT or SIGNDS, please check it before setting"));
throw new IllegalArgumentException();
}
this.encryptLevel = encryptLevel;

View File

@ -23,6 +23,8 @@ import mindspore.schema.FeatureMap;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Map;
import java.util.List;
import java.util.HashMap;
import java.util.logging.Logger;
/**
@ -46,6 +48,12 @@ public class SecureProtocol {
private double dpNormClip;
private ArrayList<String> updateFeatureName = new ArrayList<String>();
private int retCode;
private float signK;
private float signEps;
private float signThrRatio;
private float signGlobalLr;
private int signDimOut;
/**
* Obtain current status code in client.
@ -102,6 +110,22 @@ public class SecureProtocol {
return FLClientStatus.SUCCESS;
}
/**
* Setting parameters for dimension select.
*
* @param map model weights.
* @return the status code corresponding to the response message.
*/
public FLClientStatus setDSParameter(float signK, float signEps, float signThrRatio, float signGlobalLr, int signDimOut, Map<String, float[]> map) {
this.signK = signK;
this.signEps = signEps;
this.signThrRatio = signThrRatio;
this.signGlobalLr = signGlobalLr;
this.signDimOut = signDimOut;
this.modelMap = map;
return FLClientStatus.SUCCESS;
}
/**
* Obtain the feature names that needed to be encrypted.
*
@ -367,6 +391,10 @@ public class SecureProtocol {
}
}
updateL2Norm = Math.sqrt(updateL2Norm);
if (updateL2Norm == 0) {
LOGGER.severe(Common.addTag("[Encrypt] updateL2Norm is 0, please check"));
return new int[0];
}
double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
// clip and add noise
@ -412,4 +440,282 @@ public class SecureProtocol {
}
return featuresMap;
}
/**
* The number of combinations of n things taken k.
*
* @param n Number of things.
* @param k Number of elements taken.
* @return the total number of "n choose k" combinations.
*/
private static double comb(double n, double k) {
boolean cond = (k <= n) && (n >= 0) && (k >= 0);
double m = n + 1;
if (!cond) {
return 0;
} else {
double nTerm = Math.min(k, n - k);
double res = 1;
for (int i = 1; i <= nTerm; i++) {
res *= (m - i);
res /= i;
}
return res;
}
}
/**
* Calculate the number of possible combinations of output set given the number of topk dimensions.
* c(k, v) * c(d-k, h-v)
*
* @param numInter the number of dimensions from topk set.
* @param topkDim the size of top-k set.
* @param inputDim total number of dimensions in the model.
* @param outputDim the number of dimensions selected for constructing sparse local updates.
* @return the number of possible combinations of output set.
*/
private static double countCombs(int numInter, int topkDim, int inputDim, int outputDim) {
return comb(topkDim, numInter) * comb(inputDim - topkDim, outputDim - numInter);
}
/**
* Calculate the probability mass function of the number of topk dimensions in the output set.
* v is the number of dimensions from topk set.
*
* @param thr threshold of the number of topk dimensions in the output set.
* @param topkDim the size of top-k set.
* @param inputDim total number of dimensions in the model.
* @param outputDim the number of dimensions selected for constructing sparse local updates.
* @param eps the privacy budget of SignDS alg.
* @return the probability mass function.
*/
private static List<Double> calcPmf(int thr, int topkDim, int inputDim, int outputDim, float eps) {
List<Double> pmf = new ArrayList<>();
double newPmf;
for (int v = 0; v <= outputDim; v++) {
if (v < thr) {
newPmf = countCombs(v, topkDim, inputDim, outputDim);
} else {
newPmf = countCombs(v, topkDim, inputDim, outputDim) * Math.exp(eps);
}
pmf.add(newPmf);
}
double pmfSum = 0;
for (int i = 0; i < pmf.size(); i++) {
pmfSum += pmf.get(i);
}
if (pmfSum == 0) {
LOGGER.severe(Common.addTag("[SignDS] probability mass function is 0, please check"));
return new ArrayList<>();
}
for (int i = 0; i < pmf.size(); i++) {
pmf.set(i, pmf.get(i) / pmfSum);
}
return pmf;
}
/**
* Calculate the expected number of topk dimensions in the output set given outputDim.
* The size of pmf is also outputDim.
*
* @param pmf probability mass function
* @return the expectation of the topk dimensions in the output set.
*/
private static double calcExpectation(List<Double> pmf) {
double sumExpectation = 0;
for (int i = 0; i < pmf.size(); i++) {
sumExpectation += (i * pmf.get(i));
}
return sumExpectation;
}
/**
* Calculate the optimum threshold for the number of topk dimension in the output set.
* The optimum threshold is an integer among [1, outputDim], which has the largest
* expectation value.
*
* @param topkDim the size of top-k set.
* @param inputDim total number of dimensions in the model.
* @param outputDim the number of dimensions selected for constructing sparse local updates.
* @param eps the privacy budget of SignDS alg.
* @return the optimum threshold.
*/
private static int calcOptThr(int topkDim, int inputDim, int outputDim, float eps) {
double optExpect = 0;
double optT = 0;
for (int t = 1; t <= outputDim; t++) {
double newExpect = calcExpectation(calcPmf(t, topkDim, inputDim, outputDim, eps));
if (newExpect > optExpect) {
optExpect = newExpect;
optT = t;
} else {
break;
}
}
return (int) Math.max(optT, 1);
}
/**
* Tool function for finding the optimum output dimension.
* The main idea is to iteratively search for the largest output dimension while
* ensuring the expected ratio of topk dimensions in the output set larger than
* the target ratio.
*
* @param thrInterRatio threshold of the expected ratio of topk dimensions
* @param topkDim the size of top-k set.
* @param inputDim total number of dimensions in the model.
* @param eps the privacy budget of SignDS alg.
* @return the optimum output dimension.
*/
private static int findOptOutputDim(float thrInterRatio, int topkDim, int inputDim, float eps) {
int outputDim = 1;
while (true) {
int thr = calcOptThr(topkDim, inputDim, outputDim, eps);
double expectedRatio = calcExpectation(calcPmf(thr, topkDim, inputDim, outputDim, eps)) / outputDim;
if (expectedRatio < thrInterRatio || Double.isNaN(expectedRatio)) {
break;
} else {
outputDim += 1;
}
}
return Math.max(1, (outputDim - 1));
}
/**
* Determine the number of dimensions to be sampled from the topk dimension set via
* inverse sampling.
* The main steps of the trick of inverse sampling include:
* 1. Sample a random probability from the uniform distribution U(0, 1).
* 2. Calculate the cumulative distribution of numInter, namely the number of
* topk dimensions in the output set.
* 3. Compare the cumulative distribution with the random probability and determine
* the value of numInter.
*
* @param thrDim threshold of the number of topk dimensions in the output set.
* @param denominator calculate denominator given the threshold.
* @param topkDim the size of top-k set.
* @param inputDim total number of dimensions in the model.
* @param outputDim the number of dimensions selected for constructing sparse local updates.
* @param eps the privacy budget of SignDS alg.
* @return the number of dimensions to be sampled from the top-k dimension set.
*/
private static int countInters(int thrDim, double denominator, int topkDim, int inputDim, int outputDim, float eps) {
SecureRandom secureRandom = new SecureRandom();
double randomProb = secureRandom.nextDouble();
int numInter = 0;
double prob = countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
while (prob < randomProb) {
numInter += 1;
if (numInter < thrDim) {
prob += countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
} else {
prob += Math.exp(eps) * countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
}
}
return numInter;
}
/**
* SignDS model weights.
*
* @param builder the FlatBufferBuilder object used for serialization model weights.
* @param trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] signDSModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
Map<String, float[]> mapBeforeTrain = modelMap;
int layerNum = updateFeatureName.size();
int[] featuresMap = new int[layerNum];
SecureRandom secureRandom = Common.getSecureRandom();
boolean sign = secureRandom.nextBoolean();
List<String> nonTopkKeyList = new ArrayList<>();
List<String> topkKeyList = new ArrayList<>();
Map<String, Float> allUpdateMap = new HashMap<>();
for (int i = 0; i < layerNum; i++) {
String key = updateFeatureName.get(i);
float[] dataAfterTrain = trainedMap.get(key);
float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < dataAfterTrain.length; j++) {
float updateData = dataAfterTrain[j] - dataBeforeTrain[j];
String ij = Integer.toString(i) + ',' + j;
allUpdateMap.put(ij, updateData);
}
}
int inputDim = allUpdateMap.size();
int topkDim = (int) (signK * inputDim);
if (signDimOut == 0) {
signDimOut = findOptOutputDim(signThrRatio, topkDim, inputDim, signEps);
}
int thrDim = calcOptThr(topkDim, inputDim, signDimOut, signEps);
double combLessInter = 0d;
double combMoreInter = 0d;
for (int i = 0; i < thrDim; i++) {
combLessInter += countCombs(i, topkDim, inputDim, signDimOut);
}
for (int i = thrDim; i <= signDimOut; i++) {
combMoreInter += countCombs(i, topkDim, inputDim, signDimOut);
}
double denominator = combLessInter + Math.exp(signEps) * combMoreInter;
if (denominator == 0) {
LOGGER.severe(Common.addTag("[SignDS] denominator is 0, please check"));
return new int[0];
}
int numInter = countInters(thrDim, denominator, topkDim, inputDim, signDimOut, signEps);
int numOuter = signDimOut - numInter;
if (topkDim < numInter || signDimOut <= 0) {
LOGGER.severe("[SignDS] topkDim or signDimOut is ERROR! please check");
return new int[0];
}
List<Map.Entry<String, Float>> allUpdateList = new ArrayList<>(allUpdateMap.entrySet());
if (sign) {
allUpdateList.sort((o1, o2) -> Float.compare(o2.getValue(), o1.getValue()));
} else {
allUpdateList.sort((o1, o2) -> Float.compare(o1.getValue(), o2.getValue()));
}
for (int i = 0; i < topkDim; i++) {
topkKeyList.add(allUpdateList.get(i).getKey());
}
for (int i = topkDim; i < allUpdateList.size(); i++) {
nonTopkKeyList.add(allUpdateList.get(i).getKey());
}
List<String> outputDimensionIJStringList = new ArrayList<>();
for (int i = topkKeyList.size(); i > topkKeyList.size() - numInter; i--) {
int randomIndex = secureRandom.nextInt(i);
String randomChoiceTopkIJString = topkKeyList.get(randomIndex);
topkKeyList.set(randomIndex, topkKeyList.get(i - 1));
topkKeyList.set(i - 1, randomChoiceTopkIJString);
outputDimensionIJStringList.add(randomChoiceTopkIJString);
}
for (int i = nonTopkKeyList.size(); i > nonTopkKeyList.size() - numOuter; i--) {
int randomIndex = secureRandom.nextInt(i);
String randomChoiceNonTopkIJString = nonTopkKeyList.get(randomIndex);
nonTopkKeyList.set(randomIndex, nonTopkKeyList.get(i - 1));
nonTopkKeyList.set(i - 1, randomChoiceNonTopkIJString);
outputDimensionIJStringList.add(randomChoiceNonTopkIJString);
}
float signValue = sign ? 1f * signGlobalLr : -1f * signGlobalLr;
for (String ijString : outputDimensionIJStringList) {
String[] ij = ijString.split(",");
int iKeyIndex = Integer.parseInt(ij[0]);
int jDataIndex = Integer.parseInt(ij[1]);
String key = updateFeatureName.get(iKeyIndex);
float[] dataBeforeTrain = mapBeforeTrain.get(key);
dataBeforeTrain[jDataIndex] += signValue;
mapBeforeTrain.put(key, dataBeforeTrain);
}
for (int i = 0; i < layerNum; i++) {
String key = updateFeatureName.get(i);
float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < dataBeforeTrain.length; j++) {
dataBeforeTrain[j] *= trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, dataBeforeTrain);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
featuresMap[i] = featureMap;
}
return featuresMap;
}
}

View File

@ -310,6 +310,18 @@ public class UpdateModel {
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
case SIGNDS:
int[] fmOffsetsSignDS = secureProtocol.signDSModel(builder, trainDataSize, trainedMap);
if (fmOffsetsSignDS == null || fmOffsetsSignDS.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsSignDS from <secureProtocol.signDSModel> is " +
"null, please check");
retCode = ResponseCode.RequestError;
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsSignDS);
LOGGER.info(Common.addTag("[Encrypt] SignDS mask model ok!"));
return this;
case NOT_ENCRYPT:
default:
int featureSize = updateFeatureName.size();

View File

@ -1059,11 +1059,18 @@ def set_fl_context(**kwargs):
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
suggested to be 0.5~2. Default: 1.0.
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT',
'PW_ENCRYPT' or 'STABLE_PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied
'PW_ENCRYPT', 'STABLE_PW_ENCRYPT' or 'SIGNDS'. If 'DP_ENCRYPT', differential privacy schema would be applied
for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip
as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients'
model from stealing in cross-device scenario. If 'STABLE_PW_ENCRYPT', pairwise secure aggregation would
be applied to protect clients' model from stealing in cross-silo scenario. Default: 'NOT_ENCRYPT'.
be applied to protect clients' model from stealing in cross-silo scenario. If 'SIGNDS', SignDS schema would
be applied for clients. Default: 'NOT_ENCRYPT'.
sign_k (float): SignDS: Top-k ratio, namely the number of top-k dimensions divided by the total number of
dimensions. Default: 0.01.
sign_eps (float): SignDS: Privacy budget. Default: 100.
sign_thr_ratio (float): SignDS: Threshold of the expected topk dimension. Default: 0.6.
sign_global_lr (float): SignDS: The constant value assigned to the selected dimension. Default: 1.
sign_dim_out (int): SignDS: Number of output dimensions. Default: 0.
config_file_path (string): Configuration file path used by recovery. Default: ''.
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: False.

View File

@ -71,7 +71,12 @@ _set_ps_context_func_map = {
"dp_delta": ps_context().set_dp_delta,
"dp_norm_clip": ps_context().set_dp_norm_clip,
"encrypt_type": ps_context().set_encrypt_type,
"http_url_prefix": ps_context().set_http_url_prefix
"http_url_prefix": ps_context().set_http_url_prefix,
"sign_k": ps_context().set_sign_k,
"sign_eps": ps_context().set_sign_eps,
"sign_thr_ratio": ps_context().set_sign_thr_ratio,
"sign_global_lr": ps_context().set_sign_global_lr,
"sign_dim_out": ps_context().set_sign_dim_out
}
_get_ps_context_func_map = {
@ -112,7 +117,12 @@ _get_ps_context_func_map = {
"server_password": ps_context().server_password,
"scheduler_manage_port": ps_context().scheduler_manage_port,
"config_file_path": ps_context().config_file_path,
"http_url_prefix": ps_context().http_url_prefix
"http_url_prefix": ps_context().http_url_prefix,
"sign_k": ps_context().sign_k,
"sign_eps": ps_context().sign_eps,
"sign_thr_ratio": ps_context().sign_thr_ratio,
"sign_global_lr": ps_context().sign_global_lr,
"sign_dim_out": ps_context().sign_dim_out
}
_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",

View File

@ -16,15 +16,32 @@
namespace mindspore.schema;
table CipherPublicParams {
table PWParams {
t:int;
p:[ubyte];
g:int;
prime:[ubyte];
}
table DPParams {
dp_eps:float;
dp_delta:float;
dp_norm_clip:float;
}
table DSParams {
sign_k:float;
sign_eps:float;
sign_thr_ratio:float;
sign_global_lr:float;
sign_dim_out:int;
}
table CipherPublicParams {
encrypt_type:string;
pw_params:PWParams;
dp_params:DPParams;
ds_params:DSParams;
}
table ClientPublicKeys {

View File

@ -76,6 +76,12 @@ def parse_args():
parser.add_argument("--root_second_ca_path", type=str, default="")
parser.add_argument("--equip_crl_path", type=str, default="")
parser.add_argument("--replay_attack_time_diff", type=int, default=600000)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
return parser.parse_args()
@ -118,6 +124,11 @@ def server_train(args):
root_second_ca_path = args.root_second_ca_path
equip_crl_path = args.equip_crl_path
replay_attack_time_diff = args.replay_attack_time_diff
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
# Replace some parameters with federated learning parameters.
train_cfg.max_global_epoch = fl_iteration_num
@ -155,7 +166,12 @@ def server_train(args):
"root_first_ca_path": root_first_ca_path,
"root_second_ca_path": root_second_ca_path,
"equip_crl_path": equip_crl_path,
"replay_attack_time_diff": replay_attack_time_diff
"replay_attack_time_diff": replay_attack_time_diff,
"sign_k": sign_k,
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
}
if not os.path.exists(output_dir):

View File

@ -58,6 +58,12 @@ parser.add_argument("--root_first_ca_path", type=str, default="")
parser.add_argument("--root_second_ca_path", type=str, default="")
parser.add_argument("--equip_crl_path", type=str, default="")
parser.add_argument("--replay_attack_time_diff", type=int, default=600000)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -93,6 +99,11 @@ root_first_ca_path = args.root_first_ca_path
root_second_ca_path = args.root_second_ca_path
equip_crl_path = args.equip_crl_path
replay_attack_time_diff = args.replay_attack_time_diff
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
if local_server_num == -1:
local_server_num = server_num
@ -139,6 +150,11 @@ for i in range(local_server_num):
cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
cmd_server += " --equip_crl_path=" + str(equip_crl_path)
cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff)
cmd_server += " --sign_k=" + str(sign_k)
cmd_server += " --sign_eps=" + str(sign_eps)
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -49,6 +49,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -80,6 +86,11 @@ if __name__ == "__main__":
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
if local_server_num == -1:
local_server_num = server_num
@ -121,6 +132,11 @@ if __name__ == "__main__":
cmd_server += " --server_password=" + str(server_password)
cmd_server += " --enable_ssl=" + str(enable_ssl)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --sign_k=" + str(sign_k)
cmd_server += " --sign_eps=" + str(sign_eps)
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -56,6 +56,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -87,6 +93,11 @@ encrypt_type = args.encrypt_type
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
ctx = {
"enable_fl": True,
@ -117,7 +128,12 @@ ctx = {
"encrypt_type": encrypt_type,
"client_password": client_password,
"server_password": server_password,
"enable_ssl": enable_ssl
"enable_ssl": enable_ssl,
"sign_k": sign_k,
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

View File

@ -57,6 +57,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -92,6 +98,11 @@ root_first_ca_path = args.root_first_ca_path
root_second_ca_path = args.root_second_ca_path
equip_crl_path = args.equip_crl_path
replay_attack_time_diff = args.replay_attack_time_diff
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
if local_server_num == -1:
local_server_num = server_num
@ -138,6 +149,11 @@ for i in range(local_server_num):
cmd_server += " --root_second_ca_path=" + str(root_second_ca_path)
cmd_server += " --equip_crl_path=" + str(equip_crl_path)
cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff)
cmd_server += " --sign_k=" + str(sign_k)
cmd_server += " --sign_eps=" + str(sign_eps)
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -65,6 +65,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -102,6 +108,11 @@ root_first_ca_path = args.root_first_ca_path
root_second_ca_path = args.root_second_ca_path
equip_crl_path = args.equip_crl_path
replay_attack_time_diff = args.replay_attack_time_diff
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
ctx = {
"enable_fl": True,
@ -138,7 +149,12 @@ ctx = {
"encrypt_type": encrypt_type,
"client_password": client_password,
"server_password": server_password,
"enable_ssl": enable_ssl
"enable_ssl": enable_ssl,
"sign_k": sign_k,
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

View File

@ -55,6 +55,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -91,6 +97,11 @@ if __name__ == "__main__":
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
if local_server_num == -1:
local_server_num = server_num
@ -137,6 +148,11 @@ if __name__ == "__main__":
cmd_server += " --replay_attack_time_diff=" + str(replay_attack_time_diff)
cmd_server += " --enable_ssl=" + str(enable_ssl)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --sign_k=" + str(sign_k)
cmd_server += " --sign_eps=" + str(sign_eps)
cmd_server += " --sign_thr_ratio=" + str(sign_thr_ratio)
cmd_server += " --sign_global_lr=" + str(sign_global_lr)
cmd_server += " --sign_dim_out=" + str(sign_dim_out)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -64,6 +64,12 @@ parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--client_password", type=str, default="")
parser.add_argument("--server_password", type=str, default="")
parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False)
# parameters for 'SIGNDS'
parser.add_argument("--sign_k", type=float, default=0.01)
parser.add_argument("--sign_eps", type=float, default=100)
parser.add_argument("--sign_thr_ratio", type=float, default=0.6)
parser.add_argument("--sign_global_lr", type=float, default=0.1)
parser.add_argument("--sign_dim_out", type=int, default=0)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -100,6 +106,11 @@ replay_attack_time_diff = args.replay_attack_time_diff
client_password = args.client_password
server_password = args.server_password
enable_ssl = args.enable_ssl
sign_k = args.sign_k
sign_eps = args.sign_eps
sign_thr_ratio = args.sign_thr_ratio
sign_global_lr = args.sign_global_lr
sign_dim_out = args.sign_dim_out
ctx = {
"enable_fl": True,
@ -135,7 +146,12 @@ ctx = {
"encrypt_type": encrypt_type,
"client_password": client_password,
"server_password": server_password,
"enable_ssl": enable_ssl
"enable_ssl": enable_ssl,
"sign_k": sign_k,
"sign_eps": sign_eps,
"sign_thr_ratio": sign_thr_ratio,
"sign_global_lr": sign_global_lr,
"sign_dim_out": sign_dim_out
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)