forked from mindspore-Ecosystem/mindspore
Add secure parameters for mindspore federated learning.
This commit is contained in:
parent
0de3e841bf
commit
ebc71d3306
|
@ -397,7 +397,13 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_config_file_path", &PSContext::set_config_file_path,
|
||||
"Set configuration files required by the communication layer.")
|
||||
.def("config_file_path", &PSContext::config_file_path,
|
||||
"Get configuration files required by the communication layer.");
|
||||
"Get configuration files required by the communication layer.")
|
||||
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
|
||||
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
|
||||
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
|
||||
"Set dp norm clip for federated learning secure aggregation.")
|
||||
.def("set_encrypt_type", &PSContext::set_encrypt_type,
|
||||
"Set encrypt type for federated learning secure aggregation.");
|
||||
|
||||
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
||||
.def(py::init())
|
||||
|
|
|
@ -184,6 +184,47 @@ void PSContext::set_server_mode(const std::string &server_mode) {
|
|||
|
||||
const std::string &PSContext::server_mode() const { return server_mode_; }
|
||||
|
||||
void PSContext::set_encrypt_type(const std::string &encrypt_type) {
|
||||
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
|
||||
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
|
||||
<< kDPEncryptType << " or " << kPWEncryptType;
|
||||
return;
|
||||
}
|
||||
encrypt_type_ = encrypt_type;
|
||||
}
|
||||
const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
|
||||
|
||||
void PSContext::set_dp_eps(float dp_eps) {
|
||||
if (dp_eps > 0) {
|
||||
dp_eps_ = dp_eps;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
float PSContext::dp_eps() const { return dp_eps_; }
|
||||
|
||||
void PSContext::set_dp_delta(float dp_delta) {
|
||||
if (dp_delta > 0 && dp_delta < 1) {
|
||||
dp_delta_ = dp_delta;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
|
||||
return;
|
||||
}
|
||||
}
|
||||
float PSContext::dp_delta() const { return dp_delta_; }
|
||||
|
||||
void PSContext::set_dp_norm_clip(float dp_norm_clip) {
|
||||
if (dp_norm_clip > 0) {
|
||||
dp_norm_clip_ = dp_norm_clip;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
|
||||
|
||||
void PSContext::set_ms_role(const std::string &role) {
|
||||
if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
|
||||
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
|
||||
|
|
|
@ -35,9 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER";
|
|||
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
|
||||
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
|
||||
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
|
||||
constexpr char kDPEncryptType[] = "DPEncrypt";
|
||||
constexpr char kPWEncryptType[] = "PWEncrypt";
|
||||
constexpr char kNotEncryptType[] = "NotEncrypt";
|
||||
constexpr char kDPEncryptType[] = "DP_ENCRYPT";
|
||||
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
|
||||
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
|
||||
|
||||
// Use binary data to represent federated learning server's context so that we can judge which round resets the
|
||||
// iteration. From right to left, each bit stands for:
|
||||
|
@ -166,6 +166,18 @@ class PSContext {
|
|||
void set_config_file_path(const std::string &path);
|
||||
std::string config_file_path() const;
|
||||
|
||||
void set_dp_eps(float dp_eps);
|
||||
float dp_eps() const;
|
||||
|
||||
void set_dp_delta(float dp_delta);
|
||||
float dp_delta() const;
|
||||
|
||||
void set_dp_norm_clip(float dp_norm_clip);
|
||||
float dp_norm_clip() const;
|
||||
|
||||
void set_encrypt_type(const std::string &encrypt_type);
|
||||
const std::string &encrypt_type() const;
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
@ -199,7 +211,11 @@ class PSContext {
|
|||
secure_aggregation_(false),
|
||||
cluster_config_(nullptr),
|
||||
scheduler_manage_port_(11202),
|
||||
config_file_path_("") {}
|
||||
config_file_path_(""),
|
||||
dp_eps_(50),
|
||||
dp_delta_(0.01),
|
||||
dp_norm_clip_(1.0),
|
||||
encrypt_type_(kNotEncryptType) {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
|
@ -276,6 +292,18 @@ class PSContext {
|
|||
|
||||
// The path of the configuration file, used to configure the certification path and persistent storage type, etc.
|
||||
std::string config_file_path_;
|
||||
|
||||
// Epsilon budget of differential privacy mechanism. Used in federated learning for now.
|
||||
float dp_eps_;
|
||||
|
||||
// Delta budget of differential privacy mechanism. Used in federated learning for now.
|
||||
float dp_delta_;
|
||||
|
||||
// Norm clip factor of differential privacy mechanism. Used in federated learning for now.
|
||||
float dp_norm_clip_;
|
||||
|
||||
// Secure mechanism for federated learning. Used in federated learning for now.
|
||||
std::string encrypt_type_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -133,7 +133,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
|
||||
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||
if (true) { // todo: PSContext::instance()->encrypt_type == PWEncrypt {
|
||||
if (PSContext::instance()->encrypt_type() == kPWEncryptType) {
|
||||
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
}
|
||||
|
|
|
@ -96,9 +96,10 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
|
|||
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
|
||||
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
|
||||
<< total_data_size;
|
||||
|
||||
if (PSContext::instance()->encrypt_type() != kPWEncryptType) {
|
||||
FinishIteration();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {
|
||||
|
|
|
@ -235,6 +235,10 @@ void Server::InitCipher() {
|
|||
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
|
||||
int cipher_g = 1;
|
||||
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
|
||||
float dp_eps = PSContext::instance()->dp_eps();
|
||||
float dp_delta = PSContext::instance()->dp_delta();
|
||||
float dp_norm_clip = PSContext::instance()->dp_norm_clip();
|
||||
std::string encrypt_type = PSContext::instance()->encrypt_type();
|
||||
|
||||
mpz_t prim;
|
||||
mpz_init(prim);
|
||||
|
@ -248,10 +252,10 @@ void Server::InitCipher() {
|
|||
param.t = cipher_t;
|
||||
memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
|
||||
memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
|
||||
// param.dp_delta = dp_delta;
|
||||
// param.dp_eps = dp_eps;
|
||||
// param.dp_norm_clip = dp_norm_clip;
|
||||
param.encrypt_type = kNotEncryptType; // PSContext::instance()->encrypt_type;
|
||||
param.dp_delta = dp_delta;
|
||||
param.dp_eps = dp_eps;
|
||||
param.dp_norm_clip = dp_norm_clip;
|
||||
param.encrypt_type = encrypt_type;
|
||||
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
|
||||
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
|
||||
cipher_reconstruct_secrets_up_cnt_);
|
||||
|
|
|
@ -851,6 +851,17 @@ def set_fl_context(**kwargs):
|
|||
client_learning_rate (float): Client training learning rate. Default: 0.001.
|
||||
worker_step_num_per_iteration (int): The worker's standalone training step number before communicating with
|
||||
server. Default: 65.
|
||||
dp_eps (float): Epsilon budget of differential privacy mechanism. The smaller the dp_eps, the better the
|
||||
privacy protection effect. Default: 50.0.
|
||||
dp_delta (float): Delta budget of differential privacy mechanism, which is usually equals the reciprocal of
|
||||
client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
|
||||
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
|
||||
suggested to be 0.5~2. Default: 1.0.
|
||||
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
|
||||
'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy
|
||||
protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
|
||||
'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing.
|
||||
Default: 'NOT_ENCRYPT'.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not the attribute in federated learning mode context.
|
||||
|
|
|
@ -68,7 +68,11 @@ _set_ps_context_func_map = {
|
|||
"worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration,
|
||||
"enable_ps_ssl": ps_context().set_enable_ssl,
|
||||
"scheduler_manage_port": ps_context().set_scheduler_manage_port,
|
||||
"config_file_path": ps_context().set_config_file_path
|
||||
"config_file_path": ps_context().set_config_file_path,
|
||||
"dp_eps": ps_context().set_dp_eps,
|
||||
"dp_delta": ps_context().set_dp_delta,
|
||||
"dp_norm_clip": ps_context().set_dp_norm_clip,
|
||||
"encrypt_type": ps_context().set_encrypt_type
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
|
|
|
@ -38,6 +38,11 @@ parser.add_argument("--client_batch_size", type=int, default=32)
|
|||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--local_server_num", type=int, default=-1)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
|
||||
# parameters for encrypt_type='DP_ENCRYPT'
|
||||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -62,6 +67,10 @@ if __name__ == "__main__":
|
|||
client_learning_rate = args.client_learning_rate
|
||||
local_server_num = args.local_server_num
|
||||
config_file_path = args.config_file_path
|
||||
dp_eps = args.dp_eps
|
||||
dp_delta = args.dp_delta
|
||||
dp_norm_clip = args.dp_norm_clip
|
||||
encrypt_type = args.encrypt_type
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -95,6 +104,10 @@ if __name__ == "__main__":
|
|||
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
|
||||
cmd_server += " --client_batch_size=" + str(client_batch_size)
|
||||
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
|
||||
cmd_server += " --dp_eps=" + str(dp_eps)
|
||||
cmd_server += " --dp_delta=" + str(dp_delta)
|
||||
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
|
||||
cmd_server += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -46,6 +46,10 @@ parser.add_argument("--client_batch_size", type=int, default=32)
|
|||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
# parameters for encrypt_type='DP_ENCRYPT'
|
||||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -70,6 +74,10 @@ client_batch_size = args.client_batch_size
|
|||
client_learning_rate = args.client_learning_rate
|
||||
scheduler_manage_port = args.scheduler_manage_port
|
||||
config_file_path = args.config_file_path
|
||||
dp_eps = args.dp_eps
|
||||
dp_delta = args.dp_delta
|
||||
dp_norm_clip = args.dp_norm_clip
|
||||
encrypt_type = args.encrypt_type
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -93,7 +101,11 @@ ctx = {
|
|||
"client_batch_size": client_batch_size,
|
||||
"client_learning_rate": client_learning_rate,
|
||||
"scheduler_manage_port": scheduler_manage_port,
|
||||
"config_file_path": config_file_path
|
||||
"config_file_path": config_file_path,
|
||||
"dp_eps": dp_eps,
|
||||
"dp_delta": dp_delta,
|
||||
"dp_norm_clip": dp_norm_clip,
|
||||
"encrypt_type": encrypt_type
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
|
||||
|
|
Loading…
Reference in New Issue