High-level API for federated learning

w00517672 2021-11-26 15:33:09 +08:00
parent 7587544c0e
commit ea3e651b36
11 changed files with 305 additions and 95 deletions

View File

@ -28,6 +28,8 @@ from ._time_monitor import TimeMonitor
from ._summary_collector import SummaryCollector
from ._lr_scheduler_callback import LearningRateScheduler
from ._landscape import SummaryLandscape
from ._fl_manager import FederatedLearningManager
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape"]
"SummaryCollector", "CheckpointConfig", "RunContext", "LearningRateScheduler", "SummaryLandscape",
"FederatedLearningManager"]

View File

@ -0,0 +1,256 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FederatedLearningManager related class and functions."""
from copy import deepcopy
import numpy as np
from mindspore import context, nn
from mindspore.common import Parameter, ParameterTuple
from mindspore.train.callback import Callback
from mindspore.ops import operations as P
from mindspore._checkparam import Validator, Rel
class _StartFLJob(nn.Cell):
"""
StartFLJob for Federated Learning Worker.
"""
def __init__(self, data_size):
super(_StartFLJob, self).__init__()
self.start_fl_job = P.StartFLJob(data_size)
def construct(self):
succ = self.start_fl_job()
return succ
class _UpdateAndGetModel(nn.Cell):
"""
Update and Get Model for Federated Learning Worker.
"""
def __init__(self, weights):
super(_UpdateAndGetModel, self).__init__()
self.update_model = P.UpdateModel()
self.get_model = P.GetModel()
self.weights = weights
def construct(self):
self.update_model(self.weights)
succ = self.get_model(self.weights)
return succ
class _ExchangeKeys(nn.Cell):
"""
Exchange Keys for Stable PW Encrypt.
"""
def __init__(self):
super(_ExchangeKeys, self).__init__()
self.exchange_keys = P.ExchangeKeys()
def construct(self):
return self.exchange_keys()
class _GetKeys(nn.Cell):
"""
Get Keys for Stable PW Encrypt.
"""
def __init__(self):
super(_GetKeys, self).__init__()
self.get_keys = P.GetKeys()
def construct(self):
return self.get_keys()
class FederatedLearningManager(Callback):
"""
Manage Federated Learning during training.
Args:
model (nn.Cell): A training model.
sync_frequency (int): Synchronization frequency of parameters in Federated Learning.
Note:
In dataset sink mode, the unit of the frequency is the number of epochs.
Otherwise, the unit of the frequency is the number of steps.
sync_type (str): Parameter synchronization type in Federated Learning.
Supports ["fixed", "adaptive"]. Default: "fixed".
- fixed: The frequency of parameter synchronization is fixed.
- adaptive: The frequency of parameter synchronization changes adaptively.
Note:
This is an experimental prototype that is subject to change.
"""
def __init__(self, model, sync_frequency, sync_type='fixed', **kwargs):
super(FederatedLearningManager, self).__init__()
server_mode = context.get_fl_context("server_mode")
if server_mode not in ("FEDERATED_LEARNING", "HYBRID_TRAINING"):
raise ValueError("server_mode must in (\"FEDERATED_LEARNING\", \"HYBRID_TRAINING\")")
Validator.check_isinstance('model', model, nn.Cell)
Validator.check_positive_int(sync_frequency)
Validator.check_string(sync_type, ["fixed", "adaptive"])
self._model = model
self._sync_frequency = sync_frequency
self._next_sync_iter_id = self._sync_frequency
self._sync_type = sync_type
self._global_step = 0
self._data_size = 0
if self._is_adaptive_sync():
self._as_set_init_state(kwargs)
self._as_wrap_cell()
def _is_adaptive_sync(self):
"""
Determine whether adaptive frequency synchronization is required.
"""
return self._sync_type == "adaptive"
def _as_set_init_state(self, kwargs):
"""
Set the initial state for adaptive synchronization.
"""
self._as_prefix = "as_abs_grad."
self._min_consistent_rate = kwargs.get("min_consistent_rate", 1.1)
Validator.check_non_negative_float(self._min_consistent_rate)
self._min_consistent_rate_at_round = kwargs.get("min_consistent_rate_at_round", 0)
Validator.check_non_negative_int(self._min_consistent_rate_at_round)
self._ema_alpha = kwargs.get("ema_alpha", 0.5)
Validator.check_float_range(self._ema_alpha, 0.0, 1.0, Rel.INC_NEITHER)
self._observation_window_size = kwargs.get("observation_window_size", 5)
Validator.check_positive_int(self._observation_window_size)
self._frequency_increase_ratio = kwargs.get("frequency_increase_ratio", 2)
Validator.check_positive_int(self._frequency_increase_ratio)
self._unchanged_round = kwargs.get("unchanged_round", 0)
Validator.check_non_negative_int(self._unchanged_round)
self._round_id = 0
self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
if self._as_prefix not in _.name}
self._model_size = 0
self._grads_ema = dict()
self._abs_grads_ema = dict()
for param in self._model.trainable_params():
if self._as_prefix not in param.name:
self._model_size += np.prod(param.shape)
self._grads_ema[param.name] = np.zeros(param.shape)
self._abs_grads_ema[param.name] = np.zeros(param.shape)
self._model_size = float(self._model_size)
def _as_wrap_cell(self):
"""
Wrap the Cell with additional gradient-tracking parameters for adaptive synchronization.
"""
param_list = list()
for param in self._model.trainable_params():
new_param = param.clone()
new_param.name = self._as_prefix + param.name
param_list.append(new_param)
for param in param_list:
self._model.insert_param_to_cell(param.name, param, False)
def _as_set_grads(self):
"""
Set the absolute values of the gradients for adaptive synchronization.
"""
abs_grads = dict()
for param in self._model.trainable_params():
if self._as_prefix not in param.name:
abs_grads[self._as_prefix+param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
for param in self._model.trainable_params():
if self._as_prefix in param.name:
param.set_data(Parameter(abs_grads[param.name]))
def _as_analyze_gradient(self):
"""
Analyze gradient statistics and adaptively adjust the synchronization frequency.
"""
worker_num = context.get_fl_context("worker_num")
ema_alpha = self._ema_alpha
consistent_rate_sum = 0.0
grads = dict()
abs_grads = dict()
for param in self._model.trainable_params():
if self._as_prefix in param.name:
abs_grads[param.name.replace(self._as_prefix, '')] = param.asnumpy() * worker_num
else:
grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
for last_p in self._last_param:
self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[last_p]
divide_base = np.where(self._abs_grads_ema[last_p] == 0,
np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
consistent_rate_sum += np.sum(layer_consistent_rate)
consistent_rate = float(consistent_rate_sum / self._model_size)
if self._min_consistent_rate > consistent_rate:
self._min_consistent_rate = consistent_rate
self._min_consistent_rate_at_round = self._round_id
else:
if self._round_id - self._min_consistent_rate_at_round > self._observation_window_size:
if self._sync_frequency > 1 and self._round_id > self._unchanged_round:
self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \
// self._frequency_increase_ratio
self._min_consistent_rate = 1.1
self._min_consistent_rate_at_round = self._round_id
self._observation_window_size *= self._frequency_increase_ratio
for param in self._model.trainable_params():
if self._as_prefix not in param.name:
self._grads_ema[param.name] = np.zeros(param.shape)
self._abs_grads_ema[param.name] = np.zeros(param.shape)
def _as_set_last_param(self):
"""
Set the values of the last parameters for adaptive synchronization.
"""
self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
if self._as_prefix not in _.name}
def step_end(self, run_context):
"""
Synchronize parameters at the end of a step. If sync_type is "adaptive", the synchronization frequency is
adjusted adaptively here.
Args:
run_context (RunContext): Context of the train running.
"""
self._global_step += 1
cb_params = run_context.original_args()
inputs = cb_params.train_dataset_element
batch_size = inputs[0].shape[0] if isinstance(inputs, (tuple, list)) else inputs.shape[0]
self._data_size += batch_size
if context.get_fl_context("ms_role") == "MS_WORKER":
if self._global_step == self._next_sync_iter_id:
start_fl_job = _StartFLJob(self._data_size)
start_fl_job()
self._data_size = 0
if self._is_adaptive_sync():
self._as_set_grads()
update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
update_and_get_model()
self._next_sync_iter_id = self._global_step + self._sync_frequency
if self._is_adaptive_sync():
self._as_analyze_gradient()
self._round_id += 1
self._as_set_last_param()
print("sync step is: {}".format(self._global_step))

View File

@ -150,6 +150,7 @@ config_file_path: ""
encrypt_type: "NOT_ENCRYPT"
dataset_path: ""
user_id: 0
sync_type: fixed
# Number of threads used to process the dataset in parallel

View File

@ -26,6 +26,7 @@ parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -37,6 +38,7 @@ scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dataset_path = args.dataset_path
sync_type = args.sync_type
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
@ -53,6 +55,7 @@ cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --dataset_path=" + str(dataset_path)
cmd_sched += " --user_id=" + str(0)
cmd_sched += " --sync_type=" + sync_type
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])

View File

@ -36,6 +36,7 @@ parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
parser.add_argument("--dataset_path", type=str, default="")
@ -59,6 +60,7 @@ client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
sync_type = args.sync_type
dataset_path = args.dataset_path
@ -94,6 +96,7 @@ for i in range(local_server_num):
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --dataset_path=" + str(dataset_path)
cmd_server += " --user_id=" + str(0)
cmd_server += " --sync_type=" + sync_type
cmd_server += " > server.log 2>&1 &"

View File

@ -29,6 +29,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -43,6 +44,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path
sync_type = args.sync_type
if local_worker_num == -1:
local_worker_num = worker_num
@ -68,6 +70,7 @@ for i in range(local_worker_num):
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
cmd_worker += " --dataset_path=" + str(dataset_path)
cmd_worker += " --user_id=" + str(i)
cmd_worker += " --sync_type=" + sync_type
cmd_worker += " > worker.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_worker])

View File

@ -21,9 +21,7 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.train.callback import TimeMonitor
from mindspore.train.callback import TimeMonitor, FederatedLearningManager
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import save_checkpoint
@ -59,6 +57,7 @@ worker_step_num_per_iteration = config.worker_step_num_per_iteration
scheduler_manage_port = config.scheduler_manage_port
config_file_path = config.config_file_path
encrypt_type = config.encrypt_type
sync_type = config.sync_type
user_id = config.user_id
@ -125,25 +124,6 @@ def train_fasterrcnn_():
return dataset_size, dataset
class StartFLJob(nn.Cell):
def __init__(self, data_size):
super(StartFLJob, self).__init__()
self.start_fl_job = P.StartFLJob(data_size)
def construct(self):
return self.start_fl_job()
class UpdateAndGetModel(nn.Cell):
def __init__(self, weights):
super(UpdateAndGetModel, self).__init__()
self.update_model = P.UpdateModel()
self.get_model = P.GetModel()
self.weights = weights
def construct(self):
self.update_model(self.weights)
get_model = self.get_model(self.weights)
return get_model
def train():
""" train_fasterrcnn """
@ -184,6 +164,11 @@ def train():
loss = LossNet()
lr = Tensor(dynamic_lr(config, dataset_size), mstype.float32)
federated_learning_manager = FederatedLearningManager(
net,
sync_frequency=config.client_epoch_num * dataset_size,
sync_type=sync_type
)
opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
@ -194,7 +179,7 @@ def train():
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
cb = [federated_learning_manager, time_cb, loss_cb]
model = Model(net)
ckpt_path1 = os.path.join("ckpt", user)
@ -202,13 +187,7 @@ def train():
os.makedirs(ckpt_path1)
print("====================", config.client_epoch_num, fl_iteration_num, flush=True)
for iter_num in range(fl_iteration_num):
if context.get_fl_context("ms_role") == "MS_WORKER":
start_fl_job = StartFLJob(dataset_size * config.batch_size)
start_fl_job()
model.train(config.client_epoch_num, dataset, callbacks=cb)
if context.get_fl_context("ms_role") == "MS_WORKER":
update_and_get_model = UpdateAndGetModel(opt.parameters)
update_and_get_model()
ckpt_name = user + "-fast-rcnn-" + str(iter_num) + "epoch.ckpt"
ckpt_path = os.path.join(ckpt_path1, ckpt_name)
save_checkpoint(net, ckpt_path)
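The frequency passed here is the number of local steps in one federated-learning iteration (client_epoch_num epochs times dataset_size batches per epoch); in non-sink mode the callback counts steps, so this appears intended to synchronize once per iteration. A small worked example with placeholder numbers:

# Placeholder numbers only, to illustrate the sync_frequency choice above.
client_epoch_num = 1      # local epochs per FL iteration (placeholder)
dataset_size = 916        # batches per epoch (placeholder)
sync_frequency = client_epoch_num * dataset_size   # synchronize after 916 local steps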

View File

@ -26,6 +26,7 @@ parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -37,6 +38,7 @@ scheduler_port = args.scheduler_port
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dataset_path = args.dataset_path
sync_type = args.sync_type
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
@ -53,6 +55,7 @@ cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " --dataset_path=" + str(dataset_path)
cmd_sched += " --user_id=" + str(0)
cmd_sched += " --sync_type=" + sync_type
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])

View File

@ -37,6 +37,7 @@ parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -59,6 +60,7 @@ local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
dataset_path = args.dataset_path
sync_type = args.sync_type
if local_server_num == -1:
local_server_num = server_num
@ -92,6 +94,7 @@ for i in range(local_server_num):
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --dataset_path=" + str(dataset_path)
cmd_server += " --user_id=" + str(0)
cmd_server += " --sync_type=" + sync_type
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -32,6 +32,7 @@ parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -49,6 +50,7 @@ local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path
encrypt_type = args.encrypt_type
sync_type = args.sync_type
if local_worker_num == -1:
local_worker_num = worker_num
@ -77,6 +79,7 @@ for i in range(local_worker_num):
cmd_worker += " --dataset_path=" + str(dataset_path)
cmd_worker += " --encrypt_type=" + str(encrypt_type)
cmd_worker += " --user_id=" + str(i)
cmd_worker += " --sync_type=" + sync_type
cmd_worker += " > worker.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_worker])

View File

@ -28,8 +28,7 @@ import mindspore.dataset.vision.py_transforms as PV
import mindspore.dataset.transforms.py_transforms as PT
import mindspore.dataset.transforms.c_transforms as tC
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.train.callback import Callback
from mindspore.train.callback import Callback, FederatedLearningManager
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
@ -62,6 +61,7 @@ parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--dataset_path", type=str, default="")
# The user_id is used to set each worker's dataset path.
parser.add_argument("--user_id", type=str, default="0")
parser.add_argument("--sync_type", type=str, default="fixed", choices=["fixed", "adaptive"])
parser.add_argument('--img_size', type=int, default=(32, 32, 1), help='the image size of (h,w,c)')
parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when create the dataLoader')
@ -91,6 +91,7 @@ encrypt_type = args.encrypt_type
cipher_time_window = args.cipher_time_window
dataset_path = args.dataset_path
user_id = args.user_id
sync_type = args.sync_type
ctx = {
"enable_fl": True,
@ -279,56 +280,11 @@ def evalute_process(model, eval_data, img_size, batch_size):
return acc['Accuracy'], acc['Loss']
class StartFLJob(nn.Cell):
def __init__(self, data_size):
super(StartFLJob, self).__init__()
self.start_fl_job = P.StartFLJob(data_size)
def construct(self):
return self.start_fl_job()
class ExchangeKeys(nn.Cell):
def __init__(self):
super(ExchangeKeys, self).__init__()
self.exchange_keys = P.ExchangeKeys()
def construct(self):
return self.exchange_keys()
class GetKeys(nn.Cell):
def __init__(self):
super(GetKeys, self).__init__()
self.get_keys = P.GetKeys()
def construct(self):
return self.get_keys()
class UpdateAndGetModel(nn.Cell):
def __init__(self, weights):
super(UpdateAndGetModel, self).__init__()
self.update_model = P.UpdateModel()
self.get_model = P.GetModel()
self.weights = weights
def construct(self):
self.update_model(self.weights)
get_model = self.get_model(self.weights)
return get_model
def train():
epoch = fl_iteration_num
network = LeNet5(62, 3)
# define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})
# construct dataset
ds.config.set_seed(1)
data_root_path = dataset_path
user = "dataset_" + user_id
@ -339,29 +295,27 @@ def train():
print("size is ", dataset.get_dataset_size(), flush=True)
num_batches = dataset.get_dataset_size()
# define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define fl manager
federated_learning_manager = FederatedLearningManager(
network,
sync_frequency=epoch * num_batches,
sync_type=sync_type,
)
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})
loss_cb = LossGet(1, num_batches)
cbs = []
cbs = list()
cbs.append(federated_learning_manager)
cbs.append(loss_cb)
ckpt_path = "ckpt"
os.makedirs(ckpt_path)
for iter_num in range(fl_iteration_num):
if context.get_fl_context("ms_role") == "MS_WORKER":
start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
start_fl_job()
if encrypt_type == "STABLE_PW_ENCRYPT":
exchange_keys = ExchangeKeys()
exchange_keys()
get_keys = GetKeys()
get_keys()
for _ in range(epoch):
print("step is ", epoch, flush=True)
model.train(1, dataset, callbacks=cbs, dataset_sink_mode=False)
if context.get_fl_context("ms_role") == "MS_WORKER":
update_and_get_model = UpdateAndGetModel(net_opt.parameters)
update_and_get_model()
model.train(epoch, dataset, callbacks=cbs, dataset_sink_mode=False)
ckpt_name = user_id + "-fl-ms-bs32-" + str(iter_num) + "epoch.ckpt"
ckpt_name = os.path.join(ckpt_path, ckpt_name)
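When these scripts are launched with --sync_type=adaptive, the extra tuning knobs read by _as_set_init_state can be passed through the constructor as keyword arguments. A sketch; the values shown simply restate the kwargs.get defaults in _fl_manager.py:

# Adaptive-mode construction sketch; keyword names and values mirror the
# defaults read in FederatedLearningManager._as_set_init_state.
federated_learning_manager = FederatedLearningManager(
    network,
    sync_frequency=epoch * num_batches,
    sync_type="adaptive",
    min_consistent_rate=1.1,
    min_consistent_rate_at_round=0,
    ema_alpha=0.5,
    observation_window_size=5,
    frequency_increase_ratio=2,
    unchanged_round=0,
)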