forked from mindspore-Ecosystem/mindspore
High-level API for federated learning
This commit is contained in:
parent 7587544c0e
commit ea3e651b36
@@ -28,6 +28,8 @@ from ._time_monitor import TimeMonitor
from ._summary_collector import SummaryCollector
from ._lr_scheduler_callback import LearningRateScheduler
from ._landscape import SummaryLandscape
from ._fl_manager import FederatedLearningManager

__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
           "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape"]
           "SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
           "FederatedLearningManager"]
@@ -0,0 +1,256 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FederatedLearningManager related class and functions."""

from copy import deepcopy
import numpy as np
from mindspore import context, nn
from mindspore.common import Parameter, ParameterTuple
from mindspore.train.callback import Callback
from mindspore.ops import operations as P
from mindspore._checkparam import Validator, Rel


class _StartFLJob(nn.Cell):
    """
    StartFLJob for Federated Learning Worker.
    """
    def __init__(self, data_size):
        super(_StartFLJob, self).__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        succ = self.start_fl_job()
        return succ


class _UpdateAndGetModel(nn.Cell):
    """
    Update and Get Model for Federated Learning Worker.
    """
    def __init__(self, weights):
        super(_UpdateAndGetModel, self).__init__()
        self.update_model = P.UpdateModel()
        self.get_model = P.GetModel()
        self.weights = weights

    def construct(self):
        self.update_model(self.weights)
        succ = self.get_model(self.weights)
        return succ


class _ExchangeKeys(nn.Cell):
    """
    Exchange Keys for Stable PW Encrypt.
    """
    def __init__(self):
        super(_ExchangeKeys, self).__init__()
        self.exchange_keys = P.ExchangeKeys()

    def construct(self):
        return self.exchange_keys()


class _GetKeys(nn.Cell):
    """
    Get Keys for Stable PW Encrypt.
    """
    def __init__(self):
        super(_GetKeys, self).__init__()
        self.get_keys = P.GetKeys()

    def construct(self):
        return self.get_keys()


class FederatedLearningManager(Callback):
    """
    Manage Federated Learning during training.

    Args:
        model (nn.Cell): A training model.
        sync_frequency (int): Synchronization frequency of parameters in Federated Learning.
            Note:
                In dataset sink mode, the unit of the frequency is the number of epochs.
                Otherwise, the unit of the frequency is the number of steps.
        sync_type (str): Parameter synchronization type in Federated Learning.
            Supports ["fixed", "adaptive"]. Default: "fixed".

            - fixed: The frequency of parameter synchronization is fixed.
            - adaptive: The frequency of parameter synchronization changes adaptively.

    Note:
        This is an experimental prototype that is subject to change.
    """

    def __init__(self, model, sync_frequency, sync_type='fixed', **kwargs):
        super(FederatedLearningManager, self).__init__()
        server_mode = context.get_fl_context("server_mode")
        if server_mode not in ("FEDERATED_LEARNING", "HYBRID_TRAINING"):
            raise ValueError("server_mode must be in (\"FEDERATED_LEARNING\", \"HYBRID_TRAINING\")")
        Validator.check_isinstance('model', model, nn.Cell)
        Validator.check_positive_int(sync_frequency)
        Validator.check_string(sync_type, ["fixed", "adaptive"])
        self._model = model
        self._sync_frequency = sync_frequency
        self._next_sync_iter_id = self._sync_frequency
        self._sync_type = sync_type
        self._global_step = 0
        self._data_size = 0

        if self._is_adaptive_sync():
            self._as_set_init_state(kwargs)
            self._as_wrap_cell()

    def _is_adaptive_sync(self):
        """
        Determine whether adaptive frequency synchronization is required.
        """
        return self._sync_type == "adaptive"

    def _as_set_init_state(self, kwargs):
        """
        Set the initial state for adaptive synchronization.
        """
        self._as_prefix = "as_abs_grad."

        self._min_consistent_rate = kwargs.get("min_consistent_rate", 1.1)
        Validator.check_non_negative_float(self._min_consistent_rate)
        self._min_consistent_rate_at_round = kwargs.get("min_consistent_rate_at_round", 0)
        Validator.check_non_negative_int(self._min_consistent_rate_at_round)
        self._ema_alpha = kwargs.get("ema_alpha", 0.5)
        Validator.check_float_range(self._ema_alpha, 0.0, 1.0, Rel.INC_NEITHER)
        self._observation_window_size = kwargs.get("observation_window_size", 5)
        Validator.check_positive_int(self._observation_window_size)
        self._frequency_increase_ratio = kwargs.get("frequency_increase_ratio", 2)
        Validator.check_positive_int(self._frequency_increase_ratio)
        self._unchanged_round = kwargs.get("unchanged_round", 0)
        Validator.check_non_negative_int(self._unchanged_round)

        self._round_id = 0
        self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
                            if self._as_prefix not in _.name}
        self._model_size = 0
        self._grads_ema = dict()
        self._abs_grads_ema = dict()
        for param in self._model.trainable_params():
            if self._as_prefix not in param.name:
                self._model_size += np.product(param.shape)
                self._grads_ema[param.name] = np.zeros(param.shape)
                self._abs_grads_ema[param.name] = np.zeros(param.shape)
        self._model_size = float(self._model_size)

    def _as_wrap_cell(self):
        """
        Wrap the Cell for adaptive synchronization.
        """
        param_list = list()
        for param in self._model.trainable_params():
            new_param = param.clone()
            new_param.name = self._as_prefix + param.name
            param_list.append(new_param)
        for param in param_list:
            self._model.insert_param_to_cell(param.name, param, False)

    def _as_set_grads(self):
        """
        Set the absolute value of the gradient for adaptive synchronization.
        """
        abs_grads = dict()
        for param in self._model.trainable_params():
            if self._as_prefix not in param.name:
                abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
        for param in self._model.trainable_params():
            if self._as_prefix in param.name:
                param.set_data(Parameter(abs_grads[param.name]))

    def _as_analyze_gradient(self):
        """
        Analyze relevant statistics based on the gradient for adaptive synchronization.
        """
        worker_num = context.get_fl_context("worker_num")
        ema_alpha = self._ema_alpha
        consistent_rate_sum = 0.0
        grads = dict()
        abs_grads = dict()
        for param in self._model.trainable_params():
            if self._as_prefix in param.name:
                abs_grads[param.name.replace(self._as_prefix, '')] = param.asnumpy() * worker_num
            else:
                grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
        for last_p in self._last_param:
            self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
            self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[last_p]
            divide_base = np.where(self._abs_grads_ema[last_p] == 0,
                                   np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
            layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
            consistent_rate_sum += np.sum(layer_consistent_rate)

        consistent_rate = float(consistent_rate_sum / self._model_size)

        if self._min_consistent_rate > consistent_rate:
            self._min_consistent_rate = consistent_rate
            self._min_consistent_rate_at_round = self._round_id
        else:
            if self._round_id - self._min_consistent_rate_at_round > self._observation_window_size:
                if self._sync_frequency > 1 and self._round_id > self._unchanged_round:
                    self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \
                        // self._frequency_increase_ratio
                    self._min_consistent_rate = 1.1
                    self._min_consistent_rate_at_round = self._round_id
                    self._observation_window_size *= self._frequency_increase_ratio

                    for param in self._model.trainable_params():
                        if self._as_prefix not in param.name:
                            self._grads_ema[param.name] = np.zeros(param.shape)
                            self._abs_grads_ema[param.name] = np.zeros(param.shape)

    def _as_set_last_param(self):
        """
        Set the value of the last parameters for adaptive synchronization.
        """
        self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
                            if self._as_prefix not in _.name}

    def step_end(self, run_context):
        """
        Synchronize parameters at the end of a step. If sync_type is "adaptive", the synchronization frequency is
        adaptively adjusted here.

        Args:
            run_context (RunContext): Context of the training run.
        """
        self._global_step += 1
        cb_params = run_context.original_args()
        inputs = cb_params.train_dataset_element
        batch_size = inputs[0].shape[0] if isinstance(inputs, (tuple, list)) else inputs.shape[0]
        self._data_size += batch_size
        if context.get_fl_context("ms_role") == "MS_WORKER":
            if self._global_step == self._next_sync_iter_id:
                start_fl_job = _StartFLJob(self._data_size)
                start_fl_job()
                self._data_size = 0
                if self._is_adaptive_sync():
                    self._as_set_grads()
                update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
                update_and_get_model()
                self._next_sync_iter_id = self._global_step + self._sync_frequency
                if self._is_adaptive_sync():
                    self._as_analyze_gradient()
                    self._round_id += 1
                    self._as_set_last_param()

                print("sync step is: {}".format(self._global_step))
@@ -150,6 +150,7 @@ config_file_path: ""
encrypt_type: "NOT_ENCRYPT"
dataset_path: ""
user_id: 0
sync_type: fixed


# Number of threads used to process the dataset in parallel
@@ -26,6 +26,7 @@ parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

args, _ = parser.parse_known_args()
device_target = args.device_target

@@ -37,6 +38,7 @@ scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dataset_path = args.dataset_path
sync_type = args.sync_type

cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"

@@ -53,6 +55,7 @@ cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --dataset_path=" + str(dataset_path)
cmd_sched += " --user_id=" + str(0)
cmd_sched += " --sync_type=" + sync_type
cmd_sched += " > scheduler.log 2>&1 &"

subprocess.call(['bash', '-c', cmd_sched])
@@ -36,6 +36,7 @@ parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

parser.add_argument("--dataset_path", type=str, default="")

@@ -59,6 +60,7 @@ client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
sync_type = args.sync_type

dataset_path = args.dataset_path

@@ -94,6 +96,7 @@ for i in range(local_server_num):
    cmd_server += " --encrypt_type=" + str(encrypt_type)
    cmd_server += " --dataset_path=" + str(dataset_path)
    cmd_server += " --user_id=" + str(0)
    cmd_server += " --sync_type=" + sync_type
    cmd_server += " > server.log 2>&1 &"
@@ -29,6 +29,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

args, _ = parser.parse_known_args()
device_target = args.device_target

@@ -43,6 +44,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path
sync_type = args.sync_type

if local_worker_num == -1:
    local_worker_num = worker_num

@@ -68,6 +70,7 @@ for i in range(local_worker_num):
    cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
    cmd_worker += " --dataset_path=" + str(dataset_path)
    cmd_worker += " --user_id=" + str(i)
    cmd_worker += " --sync_type=" + sync_type
    cmd_worker += " > worker.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_worker])
@@ -21,9 +21,7 @@ import numpy as np

import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.train.callback import TimeMonitor
from mindspore.train.callback import TimeMonitor, FederatedLearningManager
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import save_checkpoint
@@ -59,6 +57,7 @@ worker_step_num_per_iteration = config.worker_step_num_per_iteration
scheduler_manage_port = config.scheduler_manage_port
config_file_path = config.config_file_path
encrypt_type = config.encrypt_type
sync_type = config.sync_type

user_id = config.user_id
@@ -125,25 +124,6 @@ def train_fasterrcnn_():

    return dataset_size, dataset

class StartFLJob(nn.Cell):
    def __init__(self, data_size):
        super(StartFLJob, self).__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        return self.start_fl_job()

class UpdateAndGetModel(nn.Cell):
    def __init__(self, weights):
        super(UpdateAndGetModel, self).__init__()
        self.update_model = P.UpdateModel()
        self.get_model = P.GetModel()
        self.weights = weights

    def construct(self):
        self.update_model(self.weights)
        get_model = self.get_model(self.weights)
        return get_model

def train():
    """ train_fasterrcnn """
@@ -184,6 +164,11 @@ def train():

    loss = LossNet()
    lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)
    federated_learning_manager = FederatedLearningManager(
        net,
        sync_frequency=config.client_epoch_num * dataset_size,
        sync_type=sync_type
    )
    opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
              weight_decay=config.weight_decay, loss_scale=config.loss_scale)
    net_with_loss = WithLossCell(net, loss)
@@ -194,7 +179,7 @@ def train():
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossCallBack(rank_id=rank)
    cb = [time_cb, loss_cb]
    cb = [federated_learning_manager, time_cb, loss_cb]

    model = Model(net)
    ckpt_path1 = os.path.join("ckpt", user)
@@ -202,13 +187,7 @@ def train():
        os.makedirs(ckpt_path1)
    print("====================", config.client_epoch_num, fl_iteration_num, flush=True)
    for iter_num in range(fl_iteration_num):
        if context.get_fl_context("ms_role") == "MS_WORKER":
            start_fl_job = StartFLJob(dataset_size * config.batch_size)
            start_fl_job()
        model.train(config.client_epoch_num, dataset, callbacks=cb)
        if context.get_fl_context("ms_role") == "MS_WORKER":
            update_and_get_model = UpdateAndGetModel(opt.parameters)
            update_and_get_model()
        ckpt_name = user + "-fast-rcnn-" + str(iter_num) + "epoch.ckpt"
        ckpt_path = os.path.join(ckpt_path1, ckpt_name)
        save_checkpoint(net, ckpt_path)
@@ -26,6 +26,7 @@ parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

args, _ = parser.parse_known_args()
device_target = args.device_target

@@ -37,6 +38,7 @@ scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dataset_path = args.dataset_path
sync_type = args.sync_type

cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"

@@ -53,6 +55,7 @@ cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --dataset_path=" + str(dataset_path)
cmd_sched += " --user_id=" + str(0)
cmd_sched += " --sync_type=" + sync_type
cmd_sched += " > scheduler.log 2>&1 &"

subprocess.call(['bash', '-c', cmd_sched])
@@ -37,6 +37,7 @@ parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

args, _ = parser.parse_known_args()
device_target = args.device_target

@@ -59,6 +60,7 @@ local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
dataset_path = args.dataset_path
sync_type = args.sync_type

if local_server_num == -1:
    local_server_num = server_num

@@ -92,6 +94,7 @@ for i in range(local_server_num):
    cmd_server += " --encrypt_type=" + str(encrypt_type)
    cmd_server += " --dataset_path=" + str(dataset_path)
    cmd_server += " --user_id=" + str(0)
    cmd_server += " --sync_type=" + sync_type
    cmd_server += " > server.log 2>&1 &"

import time
@@ -32,6 +32,7 @@ parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

args, _ = parser.parse_known_args()
device_target = args.device_target

@@ -49,6 +50,7 @@ local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path
encrypt_type = args.encrypt_type
sync_type = args.sync_type

if local_worker_num == -1:
    local_worker_num = worker_num

@@ -77,6 +79,7 @@ for i in range(local_worker_num):
    cmd_worker += " --dataset_path=" + str(dataset_path)
    cmd_worker += " --encrypt_type=" + str(encrypt_type)
    cmd_worker += " --user_id=" + str(i)
    cmd_worker += " --sync_type=" + sync_type
    cmd_worker += " > worker.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_worker])
@@ -28,8 +28,7 @@ import mindspore.dataset.vision.py_transforms as PV
import mindspore.dataset.transforms.py_transforms as PT
import mindspore.dataset.transforms.c_transforms as tC
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.train.callback import Callback
from mindspore.train.callback import Callback, FederatedLearningManager
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model

@@ -62,6 +61,7 @@ parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--dataset_path", type=str, default="")
# The user_id is used to set each worker's dataset path.
parser.add_argument("--user_id", type=str, default="0")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

parser.add_argument('--img_size', type=int, default=(32, 32, 1), help='the image size of (h,w,c)')
parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when create the dataLoader')

@@ -91,6 +91,7 @@ encrypt_type = args.encrypt_type
cipher_time_window = args.cipher_time_window
dataset_path = args.dataset_path
user_id = args.user_id
sync_type = args.sync_type

ctx = {
    "enable_fl": True,
@@ -279,56 +280,11 @@ def evalute_process(model, eval_data, img_size, batch_size):
    return acc['Accuracy'], acc['Loss']


class StartFLJob(nn.Cell):
    def __init__(self, data_size):
        super(StartFLJob, self).__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        return self.start_fl_job()


class ExchangeKeys(nn.Cell):
    def __init__(self):
        super(ExchangeKeys, self).__init__()
        self.exchange_keys = P.ExchangeKeys()

    def construct(self):
        return self.exchange_keys()


class GetKeys(nn.Cell):
    def __init__(self):
        super(GetKeys, self).__init__()
        self.get_keys = P.GetKeys()

    def construct(self):
        return self.get_keys()


class UpdateAndGetModel(nn.Cell):
    def __init__(self, weights):
        super(UpdateAndGetModel, self).__init__()
        self.update_model = P.UpdateModel()
        self.get_model = P.GetModel()
        self.weights = weights

    def construct(self):
        self.update_model(self.weights)
        get_model = self.get_model(self.weights)
        return get_model


def train():
    epoch = fl_iteration_num
    network = LeNet5(62, 3)

    # define the loss function
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})

    # construct dataset
    ds.config.set_seed(1)
    data_root_path = dataset_path
    user = "dataset_" + user_id
@@ -339,29 +295,27 @@ def train():
    print("size is ", dataset.get_dataset_size(), flush=True)
    num_batches = dataset.get_dataset_size()

    # define the loss function
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define fl manager
    federated_learning_manager = FederatedLearningManager(
        network,
        sync_frequency=epoch * num_batches,
        sync_type=sync_type,
    )
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})

    loss_cb = LossGet(1, num_batches)
    cbs = []
    cbs = list()
    cbs.append(federated_learning_manager)
    cbs.append(loss_cb)
    ckpt_path = "ckpt"
    os.makedirs(ckpt_path)

    for iter_num in range(fl_iteration_num):
        if context.get_fl_context("ms_role") == "MS_WORKER":
            start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
            start_fl_job()
            if encrypt_type == "STABLE_PW_ENCRYPT":
                exchange_keys = ExchangeKeys()
                exchange_keys()
                get_keys = GetKeys()
                get_keys()

        for _ in range(epoch):
            print("step is ", epoch, flush=True)
            model.train(1, dataset, callbacks=cbs, dataset_sink_mode=False)

        if context.get_fl_context("ms_role") == "MS_WORKER":
            update_and_get_model = UpdateAndGetModel(net_opt.parameters)
            update_and_get_model()
        model.train(epoch, dataset, callbacks=cbs, dataset_sink_mode=False)

        ckpt_name = user_id + "-fl-ms-bs32-" + str(iter_num) + "epoch.ckpt"
        ckpt_name = os.path.join(ckpt_path, ckpt_name)