forked from mindspore-Ecosystem/mindspore
Fix cross-silo's st (system test)
This commit is contained in:
parent
15709e196c
commit
34dced6995
@@ -27,6 +27,7 @@ class _StartFLJob(nn.Cell):
     """
     StartFLJob for Federated Learning Worker.
     """
+
     def __init__(self, data_size):
         super(_StartFLJob, self).__init__()
         self.start_fl_job = P.StartFLJob(data_size)

@@ -40,9 +41,10 @@ class _UpdateAndGetModel(nn.Cell):
     """
     Update and Get Model for Federated Learning Worker.
     """
-    def __init__(self, weights):
+
+    def __init__(self, weights, encrypt_type=""):
         super(_UpdateAndGetModel, self).__init__()
-        self.update_model = P.UpdateModel()
+        self.update_model = P.UpdateModel(encrypt_type)
         self.get_model = P.GetModel()
         self.weights = weights

@@ -56,8 +58,9 @@ class _ExchangeKeys(nn.Cell):
     """
     Exchange Keys for Stable PW Encrypt.
     """
+
     def __init__(self):
-        super(ExchangeKeys, self).__init__()
+        super(_ExchangeKeys, self).__init__()
         self.exchange_keys = P.ExchangeKeys()

     def construct(self):

@@ -68,8 +71,9 @@ class _GetKeys(nn.Cell):
     """
     Get Keys for Stable PW Encrypt.
     """
+
     def __init__(self):
-        super(GetKeys, self).__init__()
+        super(_GetKeys, self).__init__()
         self.get_keys = P.GetKeys()

     def construct(self):

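Note on the two super() fixes above: two-argument super() needs the class object itself, and the old code referred to names without the leading underscore, names that do not exist in this module, so constructing either cell raised NameError. A toy reproduction (hypothetical class name, not from this commit):

    class _Demo:
        def __init__(self):
            super(Demo, self).__init__()  # NameError: name 'Demo' is not defined

    try:
        _Demo()
    except NameError as err:
        print(err)
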
@@ -110,6 +114,10 @@ class FederatedLearningManager(Callback):
         self._sync_type = sync_type
         self._global_step = 0
         self._data_size = 0
+        self._encrypt_type = kwargs.get("encrypt_type", "NOT_ENCRYPT")
+        if self._encrypt_type != "NOT_ENCRYPT" and self._encrypt_type != "STABLE_PW_ENCRYPT":
+            raise ValueError(
+                "encrypt_type must be 'NOT_ENCRYPT' or 'STABLE_PW_ENCRYPT', but got {}.".format(self._encrypt_type))

         if self._is_adaptive_sync():
             self._as_set_init_state(kwargs)

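The keyword argument is validated eagerly, so a typo fails at construction time instead of mid-training. A minimal sketch (assuming FederatedLearningManager is importable and `net` is any trainable Cell):

    try:
        manager = FederatedLearningManager(net, sync_frequency=10,
                                           encrypt_type="PW_ENCRYPT")  # not supported here
    except ValueError as err:
        print(err)  # encrypt_type must be 'NOT_ENCRYPT' or 'STABLE_PW_ENCRYPT', but got PW_ENCRYPT.
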
@@ -172,7 +180,7 @@ class FederatedLearningManager(Callback):
         abs_grads = dict()
         for param in self._model.trainable_params():
             if self._as_prefix not in param.name:
-                abs_grads[self._as_prefix+param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
+                abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
         for param in self._model.trainable_params():
             if self._as_prefix in param.name:
                 param.set_data(Parameter(abs_grads[param.name]))

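This hunk is a pure formatting fix (spaces around +); behaviour is unchanged. For readers new to the adaptive-sync bookkeeping: each real parameter's absolute change since the last sync is stashed under a prefixed shadow name. A toy NumPy illustration (the prefix value is hypothetical):

    import numpy as np

    _as_prefix = "as_abs_grad."
    last_param = {"fc.weight": np.array([1.0, 2.0])}
    current = {"fc.weight": np.array([1.5, 1.0])}
    abs_grads = {_as_prefix + name: np.abs(current[name] - last_param[name])
                 for name in current}
    print(abs_grads)  # {'as_abs_grad.fc.weight': array([0.5, 1. ])}
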
@@ -208,7 +216,7 @@ class FederatedLearningManager(Callback):
         if self._round_id - self._min_consistent_rate_at_round > self._observation_window_size:
             if self._sync_frequency > 1 and self._round_id > self._unchanged_round:
                 self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \
-                    // self._frequency_increase_ratio
+                                       // self._frequency_increase_ratio
                 self._min_consistent_rate = 1.1
                 self._min_consistent_rate_at_round = self._round_id
                 self._observation_window_size *= self._frequency_increase_ratio

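Only the continuation line is re-aligned; the arithmetic is untouched. It is the usual integer ceiling-division idiom: (a + b - 1) // b equals ceil(a / b) for positive integers, so the sync frequency shrinks but can never drop to zero. A quick check:

    a, b = 7, 3
    assert (a + b - 1) // b == 3  # ceil(7/3)
    a, b = 6, 3
    assert (a + b - 1) // b == 2  # exact division is preserved
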
@@ -235,9 +243,7 @@ class FederatedLearningManager(Callback):
         """
         self._global_step += 1
         cb_params = run_context.original_args()
-        inputs = cb_params.train_dataset_element
-        batch_size = inputs[0].shape[0] if isinstance(inputs, (tuple, list)) else inputs.shape[0]
-        self._data_size += batch_size
+        self._data_size += cb_params.batch_num
         if context.get_fl_context("ms_role") == "MS_WORKER":
             if self._global_step == self._next_sync_iter_id:
                 start_fl_job = _StartFLJob(self._data_size)

@@ -245,7 +251,15 @@ class FederatedLearningManager(Callback):
                 self._data_size = 0
                 if self._is_adaptive_sync():
                     self._as_set_grads()
-                update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
+                if self._encrypt_type == "STABLE_PW_ENCRYPT":
+                    exchange_keys = _ExchangeKeys()
+                    exchange_keys()
+                    get_keys = _GetKeys()
+                    get_keys()
+                    update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()),
+                                                              self._encrypt_type)
+                else:
+                    update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
                 update_and_get_model()
                 self._next_sync_iter_id = self._global_step + self._sync_frequency
                 if self._is_adaptive_sync():

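The order matters for pairwise masking: each worker first publishes its key (_ExchangeKeys) and then fetches its peers' keys (_GetKeys) before the masked upload, so that the pairwise masks can cancel during server-side aggregation. A worker-side sketch of the same flow (assumes an MS_WORKER process with the federated-learning context already configured and `net` as the local model):

    from mindspore import ParameterTuple

    if encrypt_type == "STABLE_PW_ENCRYPT":
        _ExchangeKeys()()  # publish this worker's public key
        _GetKeys()()       # fetch the other workers' public keys
        _UpdateAndGetModel(ParameterTuple(net.trainable_params()), encrypt_type)()
    else:
        _UpdateAndGetModel(ParameterTuple(net.trainable_params()))()

The remaining hunks update the test's launcher and training scripts so the new option can be exercised end to end.
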
@@ -29,6 +29,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
 parser.add_argument("--local_worker_num", type=int, default=-1)
 parser.add_argument("--config_file_path", type=str, default="")
 parser.add_argument("--dataset_path", type=str, default="")
+parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
 parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])

 args, _ = parser.parse_known_args()

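Unlike sync_type, the new flag carries no choices= constraint here; bad values are caught later by the FederatedLearningManager check above. Note the script uses parse_known_args(), so unrelated flags pass through silently. A standalone sketch of that behaviour:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
    args, rest = parser.parse_known_args(
        ["--encrypt_type=STABLE_PW_ENCRYPT", "--not_our_flag=1"])
    print(args.encrypt_type)  # STABLE_PW_ENCRYPT
    print(rest)               # ['--not_our_flag=1']
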
@@ -44,6 +45,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
 local_worker_num = args.local_worker_num
 config_file_path = args.config_file_path
 dataset_path = args.dataset_path
+encrypt_type = args.encrypt_type
 sync_type = args.sync_type

 if local_worker_num == -1:

@@ -70,6 +72,7 @@ for i in range(local_worker_num):
     cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
     cmd_worker += " --dataset_path=" + str(dataset_path)
     cmd_worker += " --user_id=" + str(i)
+    cmd_worker += " --encrypt_type=" + str(encrypt_type)
     cmd_worker += " --sync_type=" + sync_type
     cmd_worker += " > worker.log 2>&1 &"

@@ -167,7 +167,8 @@ def train():
     federated_learning_manager = FederatedLearningManager(
         net,
         sync_frequency=config.client_epoch_num * dataset_size,
-        sync_type=sync_type
+        sync_type=sync_type,
+        encrypt_type=encrypt_type,
     )
     opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
               weight_decay=config.weight_decay, loss_scale=config.loss_scale)

@@ -302,6 +302,7 @@ def train():
         network,
         sync_frequency=epoch * num_batches,
         sync_type=sync_type,
+        encrypt_type=encrypt_type,
     )
     # define the optimizer
     net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)

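In both training scripts the manager is just another callback, so wiring in encryption changes nothing about the training loop itself. A minimal usage sketch (assumes `model` is a mindspore.Model wrapping the network and `dataset` is the worker's local shard; LossMonitor is optional):

    from mindspore.train.callback import LossMonitor

    model.train(client_epoch_num, dataset,
                callbacks=[federated_learning_manager, LossMonitor()])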