Fix cross-silo's st

jin-xiulang 2021-12-07 20:26:33 +08:00
parent 15709e196c
commit 34dced6995
4 changed files with 30 additions and 11 deletions

View File

@@ -27,6 +27,7 @@ class _StartFLJob(nn.Cell):
     """
     StartFLJob for Federated Learning Worker.
     """
+
     def __init__(self, data_size):
         super(_StartFLJob, self).__init__()
         self.start_fl_job = P.StartFLJob(data_size)
@@ -40,9 +41,10 @@ class _UpdateAndGetModel(nn.Cell):
     """
     Update and Get Model for Federated Learning Worker.
     """
-    def __init__(self, weights):
+
+    def __init__(self, weights, encrypt_type=""):
         super(_UpdateAndGetModel, self).__init__()
-        self.update_model = P.UpdateModel()
+        self.update_model = P.UpdateModel(encrypt_type)
         self.get_model = P.GetModel()
         self.weights = weights
 
@@ -56,8 +58,9 @@ class _ExchangeKeys(nn.Cell):
     """
     Exchange Keys for Stable PW Encrypt.
     """
+
     def __init__(self):
-        super(ExchangeKeys, self).__init__()
+        super(_ExchangeKeys, self).__init__()
         self.exchange_keys = P.ExchangeKeys()
 
     def construct(self):
@@ -68,8 +71,9 @@ class _GetKeys(nn.Cell):
     """
     Get Keys for Stable PW Encrypt.
     """
+
     def __init__(self):
-        super(GetKeys, self).__init__()
+        super(_GetKeys, self).__init__()
         self.get_keys = P.GetKeys()
 
     def construct(self):
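
The two super() fixes above are more than cosmetic: the old calls referenced names (ExchangeKeys, GetKeys) that do not exist, so constructing either cell would raise NameError the moment STABLE_PW_ENCRYPT was exercised. A minimal, self-contained repro of that failure mode (plain Python, no MindSpore required, names hypothetical):

class _Demo:
    def __init__(self):
        # Mirrors the pre-fix bug: the class is named _Demo, but super() is
        # given the non-existent name Demo.
        super(Demo, self).__init__()

try:
    _Demo()
except NameError as e:
    print(e)  # name 'Demo' is not defined
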
@@ -110,6 +114,10 @@ class FederatedLearningManager(Callback):
         self._sync_type = sync_type
         self._global_step = 0
         self._data_size = 0
+        self._encrypt_type = kwargs.get("encrypt_type", "NOT_ENCRYPT")
+        if self._encrypt_type != "NOT_ENCRYPT" and self._encrypt_type != "STABLE_PW_ENCRYPT":
+            raise ValueError(
+                "encrypt_type must be 'NOT_ENCRYPT' or 'STABLE_PW_ENCRYPT', but got {}.".format(self._encrypt_type))
 
         if self._is_adaptive_sync():
             self._as_set_init_state(kwargs)
@@ -172,7 +180,7 @@ class FederatedLearningManager(Callback):
         abs_grads = dict()
         for param in self._model.trainable_params():
             if self._as_prefix not in param.name:
-                abs_grads[self._as_prefix+param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
+                abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
         for param in self._model.trainable_params():
             if self._as_prefix in param.name:
                 param.set_data(Parameter(abs_grads[param.name]))
@@ -208,7 +216,7 @@ class FederatedLearningManager(Callback):
         if self._round_id - self._min_consistent_rate_at_round > self._observation_window_size:
             if self._sync_frequency > 1 and self._round_id > self._unchanged_round:
                 self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \
-                    // self._frequency_increase_ratio
+                                       // self._frequency_increase_ratio
                 self._min_consistent_rate = 1.1
                 self._min_consistent_rate_at_round = self._round_id
                 self._observation_window_size *= self._frequency_increase_ratio
@@ -235,9 +243,7 @@ class FederatedLearningManager(Callback):
         """
         self._global_step += 1
         cb_params = run_context.original_args()
-        inputs = cb_params.train_dataset_element
-        batch_size = inputs[0].shape[0] if isinstance(inputs, (tuple, list)) else inputs.shape[0]
-        self._data_size += batch_size
+        self._data_size += cb_params.batch_num
         if context.get_fl_context("ms_role") == "MS_WORKER":
             if self._global_step == self._next_sync_iter_id:
                 start_fl_job = _StartFLJob(self._data_size)
@@ -245,7 +251,15 @@ class FederatedLearningManager(Callback):
                 self._data_size = 0
                 if self._is_adaptive_sync():
                     self._as_set_grads()
-                update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
+                if self._encrypt_type == "STABLE_PW_ENCRYPT":
+                    exchange_keys = _ExchangeKeys()
+                    exchange_keys()
+                    get_keys = _GetKeys()
+                    get_keys()
+                    update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()),
+                                                              self._encrypt_type)
+                else:
+                    update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
                 update_and_get_model()
                 self._next_sync_iter_id = self._global_step + self._sync_frequency
                 if self._is_adaptive_sync():
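
Taken together, the callback changes give the worker a second upload path: when encrypt_type is "STABLE_PW_ENCRYPT", it first runs _ExchangeKeys and _GetKeys to establish pairwise masks, then sends the masked update via UpdateModel. A minimal usage sketch, assuming the MindSpore 1.x import path and a stand-in network (both assumptions, not part of this commit):

import mindspore.nn as nn
from mindspore.train.callback import FederatedLearningManager

net = nn.Dense(32, 10)  # stand-in network; any Cell with trainable params works
fl_manager = FederatedLearningManager(
    net,
    sync_frequency=100,                # sync with the server every 100 steps
    sync_type="fixed",
    encrypt_type="STABLE_PW_ENCRYPT",  # any other non-default value raises ValueError
)
# fl_manager is then passed in the callback list of Model.train(...).
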

View File

@@ -29,6 +29,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
 parser.add_argument("--local_worker_num", type=int, default=-1)
 parser.add_argument("--config_file_path", type=str, default="")
 parser.add_argument("--dataset_path", type=str, default="")
+parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
 parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
 
 args, _ = parser.parse_known_args()
@@ -44,6 +45,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
 local_worker_num = args.local_worker_num
 config_file_path = args.config_file_path
 dataset_path = args.dataset_path
+encrypt_type = args.encrypt_type
 sync_type = args.sync_type
 
 if local_worker_num == -1:
@@ -70,6 +72,7 @@ for i in range(local_worker_num):
     cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
     cmd_worker += " --dataset_path=" + str(dataset_path)
     cmd_worker += " --user_id=" + str(i)
+    cmd_worker += " --encrypt_type=" + str(encrypt_type)
     cmd_worker += " --sync_type=" + sync_type
     cmd_worker += " > worker.log 2>&1 &"
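
One detail worth noting in the launcher: unlike --sync_type, the new --encrypt_type argument declares no choices=, so a typo is forwarded verbatim to every worker and only rejected later by FederatedLearningManager's ValueError. A self-contained sketch of the forwarding (worker.py is a hypothetical entry point, not a file in this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
args, _ = parser.parse_known_args(["--encrypt_type", "STABLE_PW_ENCRYPT"])

cmd_worker = "python worker.py"  # hypothetical worker entry point
cmd_worker += " --encrypt_type=" + str(args.encrypt_type)
print(cmd_worker)  # python worker.py --encrypt_type=STABLE_PW_ENCRYPT
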

View File

@@ -167,7 +167,8 @@ def train():
     federated_learning_manager = FederatedLearningManager(
         net,
         sync_frequency=config.client_epoch_num * dataset_size,
-        sync_type=sync_type
+        sync_type=sync_type,
+        encrypt_type=encrypt_type,
     )
     opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
               weight_decay=config.weight_decay, loss_scale=config.loss_scale)

View File

@@ -302,6 +302,7 @@ def train():
         network,
         sync_frequency=epoch * num_batches,
         sync_type=sync_type,
+        encrypt_type=encrypt_type,
     )
     # define the optimizer
     net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
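
Both training scripts now pass an encrypt_type variable, but its definition sits outside the hunks shown here; presumably it is parsed from the --encrypt_type flag the launcher forwards, mirroring how sync_type is handled. A minimal sketch of that presumed wiring (an assumption, not text from this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
args, _ = parser.parse_known_args()

encrypt_type = args.encrypt_type  # later handed to FederatedLearningManager(...)
sync_type = args.sync_type
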