From 7fa258f17822a2388c28f2fad65ec0a5c0307549 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Thu, 9 Sep 2021 14:42:40 +0800 Subject: [PATCH] Fix cross silo running issue. --- .../ccsrc/fl/server/iteration_metrics.cc | 7 ++- mindspore/ccsrc/fl/worker/fl_worker.cc | 12 +++-- .../ps/core/communicator/tcp_communicator.cc | 5 +- .../run_cross_silo_femnist_server.py | 22 +-------- .../run_cross_silo_femnist_worker.py | 6 +++ .../test_cross_silo_femnist.py | 37 +++++--------- .../run_cross_silo_lenet_server.py | 22 +-------- .../run_cross_silo_lenet_worker.py | 6 +++ tests/st/fl/cross_silo_lenet/src/model.py | 23 +++++++++ .../cross_silo_lenet/test_cross_silo_lenet.py | 48 ++++++++----------- 10 files changed, 83 insertions(+), 105 deletions(-) diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.cc b/mindspore/ccsrc/fl/server/iteration_metrics.cc index 18d84383e14..502f21d17c3 100644 --- a/mindspore/ccsrc/fl/server/iteration_metrics.cc +++ b/mindspore/ccsrc/fl/server/iteration_metrics.cc @@ -27,8 +27,8 @@ bool IterationMetrics::Initialize() { config_ = std::make_unique(config_file_path_); MS_EXCEPTION_IF_NULL(config_); if (!config_->Initialize()) { - MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_ - << " may be invalid or not exist."; + MS_LOG(WARNING) << "Initializing for metrics failed. 
Config file path " << config_file_path_ + << " may be invalid or not exist."; return false; } @@ -62,11 +62,13 @@ bool IterationMetrics::Initialize() { } metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out); + metrics_file_.close(); } return true; } bool IterationMetrics::Summarize() { + metrics_file_.open(metrics_file_path_, std::ios::out | std::ios::app);  // Append so earlier iterations' records are kept. if (!metrics_file_.is_open()) { MS_LOG(ERROR) << "The metrics file is not opened."; return false; @@ -83,6 +85,7 @@ bool IterationMetrics::Summarize() { js_[kIterExecutionTime] = iteration_time_cost_; metrics_file_ << js_ << "\n"; (void)metrics_file_.flush(); + metrics_file_.close(); return true; } diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc index 412d88133c4..836fc264bca 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.cc +++ b/mindspore/ccsrc/fl/worker/fl_worker.cc @@ -80,13 +80,17 @@ void FLWorker::Run() { } void FLWorker::Finalize() { - MS_EXCEPTION_IF_NULL(worker_node_); - if (!worker_node_->Finish()) { - MS_LOG(ERROR) << "Worker node finishing failed."; + if (worker_node_ == nullptr) { + MS_LOG(INFO) << "The worker is not initialized yet."; return; } + + // In some cases, worker calls the Finish function while other nodes don't. So timeout is acceptable. 
+ if (!worker_node_->Finish()) { + MS_LOG(WARNING) << "Finishing worker node timeout."; + } if (!worker_node_->Stop()) { - MS_LOG(ERROR) << "Worker node stopping failed."; + MS_LOG(ERROR) << "Stopping worker node failed."; return; } } diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc index 42bb0910c31..7cb534191e3 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc @@ -76,9 +76,10 @@ bool TcpCommunicator::Start() { bool TcpCommunicator::Stop() { MS_EXCEPTION_IF_NULL(abstrace_node_); + + // In some cases, server calls the Finish function while other nodes don't. So timeout is acceptable. if (!abstrace_node_->Finish()) { - MS_LOG(ERROR) << "Finishing server node failed."; - return false; + MS_LOG(WARNING) << "Finishing server node timeout."; } if (!abstrace_node_->Stop()) { MS_LOG(ERROR) << "Stopping server node failed."; diff --git a/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_server.py b/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_server.py index 9869b4df4dd..2eee76eaf86 100644 --- a/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_server.py +++ b/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_server.py @@ -32,18 +32,10 @@ parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) parser.add_argument("--client_batch_size", type=int, default=32) -parser.add_argument("--client_learning_rate", type=float, default=0.1) +parser.add_argument("--client_learning_rate", type=float, default=0.01) parser.add_argument("--local_server_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") -# parameters for encrypt_type='DP_ENCRYPT' 
-parser.add_argument("--dp_eps", type=float, default=50.0) -parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num -parser.add_argument("--dp_norm_clip", type=float, default=1.0) -# parameters for encrypt_type='PW_ENCRYPT' -parser.add_argument("--share_secrets_ratio", type=float, default=1.0) -parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--dataset_path", type=str, default="") args, _ = parser.parse_known_args() @@ -66,12 +58,6 @@ client_learning_rate = args.client_learning_rate local_server_num = args.local_server_num config_file_path = args.config_file_path encrypt_type = args.encrypt_type -share_secrets_ratio = args.share_secrets_ratio -cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshold = args.reconstruct_secrets_threshold -dp_eps = args.dp_eps -dp_delta = args.dp_delta -dp_norm_clip = args.dp_norm_clip dataset_path = args.dataset_path if local_server_num == -1: @@ -104,12 +90,6 @@ for i in range(local_server_num): cmd_server += " --client_batch_size=" + str(client_batch_size) cmd_server += " --client_learning_rate=" + str(client_learning_rate) cmd_server += " --encrypt_type=" + str(encrypt_type) - cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) - cmd_server += " --cipher_time_window=" + str(cipher_time_window) - cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) - cmd_server += " --dp_eps=" + str(dp_eps) - cmd_server += " --dp_delta=" + str(dp_delta) - cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) cmd_server += " --dataset_path=" + str(dataset_path) cmd_server += " --user_id=" + str(0) cmd_server += " > server.log 2>&1 &" diff --git a/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py b/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py index 1be35aedb07..641c7f19c22 100644 --- 
a/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py +++ b/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py @@ -25,6 +25,8 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--client_batch_size", type=int, default=32) +parser.add_argument("--client_learning_rate", type=float, default=0.01) parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--local_worker_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") @@ -39,6 +41,8 @@ scheduler_ip = args.scheduler_ip scheduler_port = args.scheduler_port fl_iteration_num = args.fl_iteration_num client_epoch_num = args.client_epoch_num +client_batch_size = args.client_batch_size +client_learning_rate = args.client_learning_rate worker_step_num_per_iteration = args.worker_step_num_per_iteration local_worker_num = args.local_worker_num config_file_path = args.config_file_path @@ -65,6 +69,8 @@ for i in range(local_worker_num): cmd_worker += " --config_file_path=" + str(config_file_path) cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num) cmd_worker += " --client_epoch_num=" + str(client_epoch_num) + cmd_worker += " --client_batch_size=" + str(client_batch_size) + cmd_worker += " --client_learning_rate=" + str(client_learning_rate) cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration) cmd_worker += " --dataset_path=" + str(dataset_path) cmd_worker += " --user_id=" + str(i) diff --git a/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py b/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py index 84ca569332e..67e66c14c46 100644 --- a/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py +++ 
b/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py @@ -47,25 +47,22 @@ parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) parser.add_argument("--update_model_time_window", type=int, default=3000) parser.add_argument("--fl_name", type=str, default="Lenet") +# fl_iteration_num is also used as the global epoch number for Worker. parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) +# client_batch_size is also used as the batch size of each mini-batch for Worker. parser.add_argument("--client_batch_size", type=int, default=32) -parser.add_argument("--client_learning_rate", type=float, default=0.1) +# client_learning_rate is also used as the learning rate for Worker. +parser.add_argument("--client_learning_rate", type=float, default=0.01) parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") -parser.add_argument("--dp_eps", type=float, default=50.0) -parser.add_argument("--dp_delta", type=float, default=0.01) -parser.add_argument("--dp_norm_clip", type=float, default=1.0) -parser.add_argument("--share_secrets_ratio", type=float, default=1.0) -parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) parser.add_argument("--dataset_path", type=str, default="") +# The user_id is used to set each worker's dataset path. 
parser.add_argument("--user_id", type=str, default="0") parser.add_argument('--img_size', type=int, default=(32, 32, 1), help='the image size of (h,w,c)') -parser.add_argument('--batch_size', type=float, default=32, help='batch size') parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when create the dataLoader') args, _ = parser.parse_known_args() @@ -90,12 +87,6 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration scheduler_manage_port = args.scheduler_manage_port config_file_path = args.config_file_path encrypt_type = args.encrypt_type -share_secrets_ratio = args.share_secrets_ratio -cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshold = args.reconstruct_secrets_threshold -dp_eps = args.dp_eps -dp_delta = args.dp_delta -dp_norm_clip = args.dp_norm_clip dataset_path = args.dataset_path user_id = args.user_id @@ -120,12 +111,6 @@ ctx = { "worker_step_num_per_iteration": worker_step_num_per_iteration, "scheduler_manage_port": scheduler_manage_port, "config_file_path": config_file_path, - "share_secrets_ratio": share_secrets_ratio, - "cipher_time_window": cipher_time_window, - "reconstruct_secrets_threshold": reconstruct_secrets_threshold, - "dp_eps": dp_eps, - "dp_delta": dp_delta, - "dp_norm_clip": dp_norm_clip, "encrypt_type": encrypt_type } @@ -314,13 +299,13 @@ class UpdateAndGetModel(nn.Cell): def train(): - epoch = client_epoch_num + epoch = fl_iteration_num network = LeNet5(62, 3) # define the loss function net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') # define the optimizer - net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()}) ds.config.set_seed(1) @@ -329,7 +314,7 @@ def train(): train_path = os.path.join(data_root_path, user, "train") test_path = os.path.join(data_root_path, 
user, "test") - dataset = create_dataset_from_folder(train_path, args.img_size, args.batch_size, args.repeat_size) + dataset = create_dataset_from_folder(train_path, args.img_size, args.client_batch_size, args.repeat_size) print("size is ", dataset.get_dataset_size(), flush=True) num_batches = dataset.get_dataset_size() @@ -341,7 +326,7 @@ def train(): for iter_num in range(fl_iteration_num): if context.get_fl_context("ms_role") == "MS_WORKER": - start_fl_job = StartFLJob(dataset.get_dataset_size() * args.batch_size) + start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size) start_fl_job() for _ in range(epoch): @@ -356,8 +341,8 @@ def train(): ckpt_name = os.path.join(ckpt_path, ckpt_name) save_checkpoint(network, ckpt_name) - train_acc, _ = evalute_process(model, train_path, args.img_size, args.batch_size) - test_acc, _ = evalute_process(model, test_path, args.img_size, args.batch_size) + train_acc, _ = evalute_process(model, train_path, args.img_size, args.client_batch_size) + test_acc, _ = evalute_process(model, test_path, args.img_size, args.client_batch_size) loss_list = loss_cb.get_loss() loss = sum(loss_list) / len(loss_list) print('local epoch: {}, loss: {}, trian acc: {}, test acc: {}'.format(iter_num, loss, train_acc, test_acc), diff --git a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py index 81b33faae92..030087b47dc 100644 --- a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py +++ b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py @@ -32,18 +32,10 @@ parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) parser.add_argument("--client_batch_size", type=int, default=32) -parser.add_argument("--client_learning_rate", type=float, default=0.1) +parser.add_argument("--client_learning_rate", type=float, 
default=0.01) parser.add_argument("--local_server_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") -# parameters for encrypt_type='DP_ENCRYPT' -parser.add_argument("--dp_eps", type=float, default=50.0) -parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num -parser.add_argument("--dp_norm_clip", type=float, default=1.0) -# parameters for encrypt_type='PW_ENCRYPT' -parser.add_argument("--share_secrets_ratio", type=float, default=1.0) -parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -65,12 +57,6 @@ client_learning_rate = args.client_learning_rate local_server_num = args.local_server_num config_file_path = args.config_file_path encrypt_type = args.encrypt_type -share_secrets_ratio = args.share_secrets_ratio -cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshold = args.reconstruct_secrets_threshold -dp_eps = args.dp_eps -dp_delta = args.dp_delta -dp_norm_clip = args.dp_norm_clip if local_server_num == -1: local_server_num = server_num @@ -102,12 +88,6 @@ for i in range(local_server_num): cmd_server += " --client_batch_size=" + str(client_batch_size) cmd_server += " --client_learning_rate=" + str(client_learning_rate) cmd_server += " --encrypt_type=" + str(encrypt_type) - cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) - cmd_server += " --cipher_time_window=" + str(cipher_time_window) - cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) - cmd_server += " --dp_eps=" + str(dp_eps) - cmd_server += " --dp_delta=" + str(dp_delta) - cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) cmd_server += " > server.log 2>&1 &" import time diff --git 
a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py index b2e14650bcf..d91fb2a1589 100644 --- a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py +++ b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py @@ -25,6 +25,8 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--client_batch_size", type=int, default=32) +parser.add_argument("--client_learning_rate", type=float, default=0.01) parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--local_worker_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") @@ -38,6 +40,8 @@ scheduler_ip = args.scheduler_ip scheduler_port = args.scheduler_port fl_iteration_num = args.fl_iteration_num client_epoch_num = args.client_epoch_num +client_batch_size = args.client_batch_size +client_learning_rate = args.client_learning_rate worker_step_num_per_iteration = args.worker_step_num_per_iteration local_worker_num = args.local_worker_num config_file_path = args.config_file_path @@ -64,6 +68,8 @@ for i in range(local_worker_num): cmd_worker += " --config_file_path=" + str(config_file_path) cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num) cmd_worker += " --client_epoch_num=" + str(client_epoch_num) + cmd_worker += " --client_batch_size=" + str(client_batch_size) + cmd_worker += " --client_learning_rate=" + str(client_learning_rate) cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration) cmd_worker += " > worker.log 2>&1 &" diff --git a/tests/st/fl/cross_silo_lenet/src/model.py b/tests/st/fl/cross_silo_lenet/src/model.py index aba4940499e..3c60b28550f 100644 --- 
a/tests/st/fl/cross_silo_lenet/src/model.py +++ b/tests/st/fl/cross_silo_lenet/src/model.py @@ -14,6 +14,7 @@ # ============================================================================ import mindspore.nn as nn +from mindspore.ops import operations as P from mindspore.common.initializer import TruncatedNormal def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): @@ -70,3 +71,25 @@ class LeNet5(nn.Cell): x = self.relu(x) x = self.fc3(x) return x + + +class StartFLJob(nn.Cell): + def __init__(self, data_size): + super(StartFLJob, self).__init__() + self.start_fl_job = P.StartFLJob(data_size) + + def construct(self): + return self.start_fl_job() + + +class UpdateAndGetModel(nn.Cell): + def __init__(self, weights): + super(UpdateAndGetModel, self).__init__() + self.update_model = P.UpdateModel() + self.get_model = P.GetModel() + self.weights = weights + + def construct(self): + self.update_model(self.weights) + get_model = self.get_model(self.weights) + return get_model diff --git a/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py b/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py index b2745c21e4e..96b22c9ccd3 100644 --- a/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py +++ b/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py @@ -19,9 +19,8 @@ import numpy as np import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor -from mindspore.nn import WithLossCell -from src.cell_wrapper import TrainOneStepCellForFLWorker -from src.model import LeNet5 +from mindspore.nn import WithLossCell, TrainOneStepCell +from src.model import LeNet5, StartFLJob, UpdateAndGetModel parser = argparse.ArgumentParser(description="test_cross_silo_lenet") parser.add_argument("--device_target", type=str, default="GPU") @@ -37,22 +36,17 @@ parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) 
parser.add_argument("--update_model_time_window", type=int, default=3000) parser.add_argument("--fl_name", type=str, default="Lenet") +# fl_iteration_num is also used as the global epoch number for Worker. parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) +# client_batch_size is also used as the batch size of each mini-batch for Worker. parser.add_argument("--client_batch_size", type=int, default=32) -parser.add_argument("--client_learning_rate", type=float, default=0.1) +# client_learning_rate is also used as the learning rate for Worker. +parser.add_argument("--client_learning_rate", type=float, default=0.01) parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") -# parameters for encrypt_type='DP_ENCRYPT' -parser.add_argument("--dp_eps", type=float, default=50.0) -parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num -parser.add_argument("--dp_norm_clip", type=float, default=1.0) -# parameters for encrypt_type='PW_ENCRYPT' -parser.add_argument("--share_secrets_ratio", type=float, default=1.0) -parser.add_argument("--cipher_time_window", type=int, default=300000) -parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -76,12 +70,6 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration scheduler_manage_port = args.scheduler_manage_port config_file_path = args.config_file_path encrypt_type = args.encrypt_type -share_secrets_ratio = args.share_secrets_ratio -cipher_time_window = args.cipher_time_window -reconstruct_secrets_threshold = args.reconstruct_secrets_threshold -dp_eps = args.dp_eps -dp_delta = args.dp_delta -dp_norm_clip = 
args.dp_norm_clip ctx = { "enable_fl": True, @@ -104,12 +92,6 @@ ctx = { "worker_step_num_per_iteration": worker_step_num_per_iteration, "scheduler_manage_port": scheduler_manage_port, "config_file_path": config_file_path, - "share_secrets_ratio": share_secrets_ratio, - "cipher_time_window": cipher_time_window, - "reconstruct_secrets_threshold": reconstruct_secrets_threshold, - "dp_eps": dp_eps, - "dp_delta": dp_delta, - "dp_norm_clip": dp_norm_clip, "encrypt_type": encrypt_type } @@ -117,19 +99,27 @@ context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_g context.set_fl_context(**ctx) if __name__ == "__main__": - epoch = 50000 + epoch = fl_iteration_num np.random.seed(0) network = LeNet5(62) criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9) net_with_criterion = WithLossCell(network, criterion) - train_network = TrainOneStepCellForFLWorker(net_with_criterion, net_opt) + train_network = TrainOneStepCell(net_with_criterion, net_opt) train_network.set_train() losses = [] for _ in range(epoch): - data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) - label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32)) + if context.get_fl_context("ms_role") == "MS_WORKER": + start_fl_job = StartFLJob(client_batch_size)  # One synthetic batch is trained per iteration. + start_fl_job() + + data = Tensor(np.random.rand(client_batch_size, 3, 32, 32).astype(np.float32)) + label = Tensor(np.random.randint(0, 61, (client_batch_size)).astype(np.int32)) loss = train_network(data, label).asnumpy() losses.append(loss) + + if context.get_fl_context("ms_role") == "MS_WORKER": + update_and_get_model = UpdateAndGetModel(net_opt.parameters) + update_and_get_model() print(losses)