Add secure parameters for mindspore federated learning.

This commit is contained in:
jin-xiulang 2021-07-03 13:26:19 +08:00
parent 0de3e841bf
commit ebc71d3306
10 changed files with 134 additions and 14 deletions

View File

@ -397,7 +397,13 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_config_file_path", &PSContext::set_config_file_path,
"Set configuration files required by the communication layer.")
.def("config_file_path", &PSContext::config_file_path,
"Get configuration files required by the communication layer.");
"Get configuration files required by the communication layer.")
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
"Set dp norm clip for federated learning secure aggregation.")
.def("set_encrypt_type", &PSContext::set_encrypt_type,
"Set encrypt type for federated learning secure aggregation.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())

View File

@ -184,6 +184,47 @@ void PSContext::set_server_mode(const std::string &server_mode) {
const std::string &PSContext::server_mode() const { return server_mode_; }
void PSContext::set_encrypt_type(const std::string &encrypt_type) {
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
<< kDPEncryptType << " or " << kPWEncryptType;
return;
}
encrypt_type_ = encrypt_type;
}
const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
void PSContext::set_dp_eps(float dp_eps) {
if (dp_eps > 0) {
dp_eps_ = dp_eps;
} else {
MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
return;
}
}
float PSContext::dp_eps() const { return dp_eps_; }
void PSContext::set_dp_delta(float dp_delta) {
if (dp_delta > 0 && dp_delta < 1) {
dp_delta_ = dp_delta;
} else {
MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
return;
}
}
float PSContext::dp_delta() const { return dp_delta_; }
void PSContext::set_dp_norm_clip(float dp_norm_clip) {
if (dp_norm_clip > 0) {
dp_norm_clip_ = dp_norm_clip;
} else {
MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
return;
}
}
float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
void PSContext::set_ms_role(const std::string &role) {
if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";

View File

@ -35,9 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
constexpr char kDPEncryptType[] = "DPEncrypt";
constexpr char kPWEncryptType[] = "PWEncrypt";
constexpr char kNotEncryptType[] = "NotEncrypt";
constexpr char kDPEncryptType[] = "DP_ENCRYPT";
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
// Use binary data to represent federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
@ -166,6 +166,18 @@ class PSContext {
void set_config_file_path(const std::string &path);
std::string config_file_path() const;
void set_dp_eps(float dp_eps);
float dp_eps() const;
void set_dp_delta(float dp_delta);
float dp_delta() const;
void set_dp_norm_clip(float dp_norm_clip);
float dp_norm_clip() const;
void set_encrypt_type(const std::string &encrypt_type);
const std::string &encrypt_type() const;
private:
PSContext()
: ps_enabled_(false),
@ -199,7 +211,11 @@ class PSContext {
secure_aggregation_(false),
cluster_config_(nullptr),
scheduler_manage_port_(11202),
config_file_path_("") {}
config_file_path_(""),
dp_eps_(50),
dp_delta_(0.01),
dp_norm_clip_(1.0),
encrypt_type_(kNotEncryptType) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -276,6 +292,18 @@ class PSContext {
// The path of the configuration file, used to configure the certification path and persistent storage type, etc.
std::string config_file_path_;
// Epsilon budget of differential privacy mechanism. Used in federated learning for now.
float dp_eps_;
// Delta budget of differential privacy mechanism. Used in federated learning for now.
float dp_delta_;
// Norm clip factor of differential privacy mechanism. Used in federated learning for now.
float dp_norm_clip_;
// Secure mechanism for federated learning. Used in federated learning for now.
std::string encrypt_type_;
};
} // namespace ps
} // namespace mindspore

View File

@ -133,7 +133,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
if (true) { // todo: PSContext::instance()->encrypt_type == PWEncrypt {
if (PSContext::instance()->encrypt_type() == kPWEncryptType) {
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}

View File

@ -96,9 +96,10 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size;
if (PSContext::instance()->encrypt_type() != kPWEncryptType) {
FinishIteration();
}
}
}
bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb) {

View File

@ -235,6 +235,10 @@ void Server::InitCipher() {
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
float dp_eps = PSContext::instance()->dp_eps();
float dp_delta = PSContext::instance()->dp_delta();
float dp_norm_clip = PSContext::instance()->dp_norm_clip();
std::string encrypt_type = PSContext::instance()->encrypt_type();
mpz_t prim;
mpz_init(prim);
@ -248,10 +252,10 @@ void Server::InitCipher() {
param.t = cipher_t;
memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
// param.dp_delta = dp_delta;
// param.dp_eps = dp_eps;
// param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = kNotEncryptType; // PSContext::instance()->encrypt_type;
param.dp_delta = dp_delta;
param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = encrypt_type;
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_reconstruct_secrets_up_cnt_);

View File

@ -851,6 +851,17 @@ def set_fl_context(**kwargs):
client_learning_rate (float): Client training learning rate. Default: 0.001.
worker_step_num_per_iteration (int): The worker's standalone training step number before communicating with
server. Default: 65.
dp_eps (float): Epsilon budget of differential privacy mechanism. The smaller the dp_eps, the better the
privacy protection effect. Default: 50.0.
dp_delta (float): Delta budget of differential privacy mechanism, which is usually equals the reciprocal of
client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
suggested to be 0.5~2. Default: 1.0.
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy
protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing.
Default: 'NOT_ENCRYPT'.
Raises:
ValueError: If input key is not the attribute in federated learning mode context.

View File

@ -68,7 +68,11 @@ _set_ps_context_func_map = {
"worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration,
"enable_ps_ssl": ps_context().set_enable_ssl,
"scheduler_manage_port": ps_context().set_scheduler_manage_port,
"config_file_path": ps_context().set_config_file_path
"config_file_path": ps_context().set_config_file_path,
"dp_eps": ps_context().set_dp_eps,
"dp_delta": ps_context().set_dp_delta,
"dp_norm_clip": ps_context().set_dp_norm_clip,
"encrypt_type": ps_context().set_encrypt_type
}
_get_ps_context_func_map = {

View File

@ -38,6 +38,11 @@ parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -62,6 +67,10 @@ if __name__ == "__main__":
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
if local_server_num == -1:
local_server_num = server_num
@ -95,6 +104,10 @@ if __name__ == "__main__":
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -46,6 +46,10 @@ parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -70,6 +74,10 @@ client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type
ctx = {
"enable_fl": True,
@ -93,7 +101,11 @@ ctx = {
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path
"config_file_path": config_file_path,
"dp_eps": dp_eps,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)