diff --git a/tests/st/fl/cross_silo_lenet/config.json b/tests/st/fl/cross_silo_lenet/config.json new file mode 100644 index 00000000000..37ac6edfb25 --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/config.json @@ -0,0 +1,6 @@ +{ + "recovery": { + "storge_type": 1, + "storage_file_path": "recovery.json" + } +} \ No newline at end of file diff --git a/tests/st/fl/cross_silo_lenet/finish_cross_silo_lenet.py b/tests/st/fl/cross_silo_lenet/finish_cross_silo_lenet.py new file mode 100644 index 00000000000..489a8671961 --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/finish_cross_silo_lenet.py @@ -0,0 +1,29 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import subprocess + +parser = argparse.ArgumentParser(description="Finish test_cross_silo_lenet.py case") +parser.add_argument("--scheduler_port", type=int, default=8113) + +args, _ = parser.parse_known_args() +scheduler_port = args.scheduler_port + +cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" " +cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && " +cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done" + +subprocess.call(['bash', '-c', cmd]) diff --git a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_sched.py b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_sched.py new file mode 100644 index 00000000000..48e79094476 --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_sched.py @@ -0,0 +1,54 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import subprocess + +parser = argparse.ArgumentParser(description="Run test_cross_silo_lenet.py case") +parser.add_argument("--device_target", type=str, default="CPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--worker_num", type=int, default=1) +parser.add_argument("--server_num", type=int, default=2) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--scheduler_manage_port", type=int, default=11202) +parser.add_argument("--config_file_path", type=str, default="") + +args, _ = parser.parse_known_args() +device_target = args.device_target +server_mode = args.server_mode +worker_num = args.worker_num +server_num = args.server_num +scheduler_ip = args.scheduler_ip +scheduler_port = args.scheduler_port +scheduler_manage_port = args.scheduler_manage_port +config_file_path = args.config_file_path + +cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" +cmd_sched += "mkdir ${execute_path}/scheduler/ &&" +cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&" +cmd_sched += "python ${self_path}/../test_cross_silo_lenet.py" +cmd_sched += " --device_target=" + device_target +cmd_sched += " --server_mode=" + server_mode +cmd_sched += " --ms_role=MS_SCHED" +cmd_sched += " --worker_num=" + str(worker_num) +cmd_sched += " --server_num=" + str(server_num) +cmd_sched += " --config_file_path=" + str(config_file_path) +cmd_sched += " --scheduler_ip=" + scheduler_ip +cmd_sched += " --scheduler_port=" + str(scheduler_port) +cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port) +cmd_sched += " > scheduler.log 2>&1 &" + +subprocess.call(['bash', '-c', cmd_sched]) diff --git a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py new file mode 100644 index 00000000000..81b33faae92 --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_server.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import subprocess + +parser = argparse.ArgumentParser(description="Run test_cross_silo_lenet.py case") +parser.add_argument("--device_target", type=str, default="CPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--worker_num", type=int, default=1) +parser.add_argument("--server_num", type=int, default=2) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--fl_server_port", type=int, default=6666) +parser.add_argument("--start_fl_job_threshold", type=int, default=1) +parser.add_argument("--start_fl_job_time_window", type=int, default=3000) +parser.add_argument("--update_model_ratio", type=float, default=1.0) +parser.add_argument("--update_model_time_window", type=int, default=3000) +parser.add_argument("--fl_name", type=str, default="Lenet") +parser.add_argument("--fl_iteration_num", type=int, default=25) +parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--client_batch_size", type=int, default=32) +parser.add_argument("--client_learning_rate", type=float, default=0.1) +parser.add_argument("--local_server_num", type=int, default=-1) +parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +# parameters for encrypt_type='DP_ENCRYPT' +parser.add_argument("--dp_eps", type=float, default=50.0) +parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num +parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) + +args, _ = parser.parse_known_args() +device_target = args.device_target +server_mode = args.server_mode +worker_num = args.worker_num +server_num = args.server_num +scheduler_ip = args.scheduler_ip +scheduler_port = args.scheduler_port +fl_server_port = args.fl_server_port +start_fl_job_threshold = args.start_fl_job_threshold +start_fl_job_time_window = args.start_fl_job_time_window +update_model_ratio = args.update_model_ratio +update_model_time_window = args.update_model_time_window +fl_name = args.fl_name +fl_iteration_num = args.fl_iteration_num +client_epoch_num = args.client_epoch_num +client_batch_size = args.client_batch_size +client_learning_rate = args.client_learning_rate +local_server_num = args.local_server_num +config_file_path = args.config_file_path +encrypt_type = args.encrypt_type +share_secrets_ratio = args.share_secrets_ratio +cipher_time_window = args.cipher_time_window +reconstruct_secrets_threshold = args.reconstruct_secrets_threshold +dp_eps = args.dp_eps +dp_delta = args.dp_delta +dp_norm_clip = args.dp_norm_clip + +if local_server_num == -1: + local_server_num = server_num + +assert local_server_num <= server_num, "The local server number should not be bigger than total server number." + +for i in range(local_server_num): + cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && " + cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&" + cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&" + cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&" + cmd_server += "python ${self_path}/../test_cross_silo_lenet.py" + cmd_server += " --device_target=" + device_target + cmd_server += " --server_mode=" + server_mode + cmd_server += " --ms_role=MS_SERVER" + cmd_server += " --worker_num=" + str(worker_num) + cmd_server += " --server_num=" + str(server_num) + cmd_server += " --scheduler_ip=" + scheduler_ip + cmd_server += " --scheduler_port=" + str(scheduler_port) + cmd_server += " --fl_server_port=" + str(fl_server_port + i) + cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold) + cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window) + cmd_server += " --update_model_ratio=" + str(update_model_ratio) + cmd_server += " --update_model_time_window=" + str(update_model_time_window) + cmd_server += " --fl_name=" + fl_name + cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) + cmd_server += " --config_file_path=" + str(config_file_path) + cmd_server += " --client_epoch_num=" + str(client_epoch_num) + cmd_server += " --client_batch_size=" + str(client_batch_size) + cmd_server += " --client_learning_rate=" + str(client_learning_rate) + cmd_server += " --encrypt_type=" + str(encrypt_type) + cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) + cmd_server += " --cipher_time_window=" + str(cipher_time_window) + cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) + cmd_server += " --dp_eps=" + str(dp_eps) + cmd_server += " --dp_delta=" + str(dp_delta) + cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) + cmd_server += " > server.log 2>&1 &" + + import time + time.sleep(0.3) + subprocess.call(['bash', '-c', cmd_server]) diff --git a/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py new file mode 100644 index 00000000000..b2e14650bcf --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/run_cross_silo_lenet_worker.py @@ -0,0 +1,70 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import subprocess + +parser = argparse.ArgumentParser(description="Run test_cross_silo_lenet.py case") +parser.add_argument("--device_target", type=str, default="GPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--worker_num", type=int, default=1) +parser.add_argument("--server_num", type=int, default=2) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--fl_iteration_num", type=int, default=25) +parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) +parser.add_argument("--local_worker_num", type=int, default=-1) +parser.add_argument("--config_file_path", type=str, default="") + +args, _ = parser.parse_known_args() +device_target = args.device_target +server_mode = args.server_mode +worker_num = args.worker_num +server_num = args.server_num +scheduler_ip = args.scheduler_ip +scheduler_port = args.scheduler_port +fl_iteration_num = args.fl_iteration_num +client_epoch_num = args.client_epoch_num +worker_step_num_per_iteration = args.worker_step_num_per_iteration +local_worker_num = args.local_worker_num +config_file_path = args.config_file_path + +if local_worker_num == -1: + local_worker_num = worker_num + +assert local_worker_num <= worker_num, "The local worker number should not be bigger than total worker number." + +for i in range(local_worker_num): + cmd_worker = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && " + cmd_worker += "rm -rf ${execute_path}/worker_" + str(i) + "/ &&" + cmd_worker += "mkdir ${execute_path}/worker_" + str(i) + "/ &&" + cmd_worker += "cd ${execute_path}/worker_" + str(i) + "/ || exit && export GLOG_v=1 && " + cmd_worker += "export CUDA_VISIBLE_DEVICES=" + str(i) +" && " + cmd_worker += "python ${self_path}/../test_cross_silo_lenet.py" + cmd_worker += " --device_target=" + device_target + cmd_worker += " --server_mode=" + server_mode + cmd_worker += " --ms_role=MS_WORKER" + cmd_worker += " --worker_num=" + str(worker_num) + cmd_worker += " --server_num=" + str(server_num) + cmd_worker += " --scheduler_ip=" + scheduler_ip + cmd_worker += " --scheduler_port=" + str(scheduler_port) + cmd_worker += " --config_file_path=" + str(config_file_path) + cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num) + cmd_worker += " --client_epoch_num=" + str(client_epoch_num) + cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration) + cmd_worker += " > worker.log 2>&1 &" + + subprocess.call(['bash', '-c', cmd_worker]) diff --git a/tests/st/fl/cross_silo_lenet/src/cell_wrapper.py b/tests/st/fl/cross_silo_lenet/src/cell_wrapper.py new file mode 100644 index 00000000000..2f8053d435d --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/src/cell_wrapper.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.nn import TrainOneStepCell + +class TrainOneStepCellForFLWorker(TrainOneStepCell): + """ + Wraps the network with federated learning operators in worker. + """ + def __init__(self, network, optimizer, sens=1.0, batch_size=32): + super(TrainOneStepCellForFLWorker, self).__init__(network, optimizer, sens) + self.batch_size = batch_size + self.start_fl_job = P.StartFLJob(batch_size) + self.update_model = P.UpdateModel() + self.get_model = P.GetModel() + self.depend = P.Depend() + + def construct(self, *inputs): + start_fl_job = self.start_fl_job() + inputs = self.depend(inputs, start_fl_job) + + loss = self.network(*inputs) + sens = F.fill(loss.dtype, loss.shape, self.sens) + grads = self.grad(self.network, self.weights)(*inputs, sens) + grads = self.grad_reducer(grads) + loss = self.depend(loss, self.optimizer(grads)) + + self.update_model(self.weights) + get_model = self.get_model(self.weights) + + loss = self.depend(loss, get_model) + return loss diff --git a/tests/st/fl/cross_silo_lenet/src/model.py b/tests/st/fl/cross_silo_lenet/src/model.py new file mode 100644 index 00000000000..aba4940499e --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/src/model.py @@ -0,0 +1,72 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_init=weight, + has_bias=False, + pad_mode="valid", + ) + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + def __init__(self, num_class=10, channel=3): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x diff --git a/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py b/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py new file mode 100644 index 00000000000..b2745c21e4e --- /dev/null +++ b/tests/st/fl/cross_silo_lenet/test_cross_silo_lenet.py @@ -0,0 +1,135 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import WithLossCell +from src.cell_wrapper import TrainOneStepCellForFLWorker +from src.model import LeNet5 + +parser = argparse.ArgumentParser(description="test_cross_silo_lenet") +parser.add_argument("--device_target", type=str, default="GPU") +parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING") +parser.add_argument("--ms_role", type=str, default="MS_WORKER") +parser.add_argument("--worker_num", type=int, default=1) +parser.add_argument("--server_num", type=int, default=1) +parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") +parser.add_argument("--scheduler_port", type=int, default=8113) +parser.add_argument("--fl_server_port", type=int, default=6666) +parser.add_argument("--start_fl_job_threshold", type=int, default=1) +parser.add_argument("--start_fl_job_time_window", type=int, default=3000) +parser.add_argument("--update_model_ratio", type=float, default=1.0) +parser.add_argument("--update_model_time_window", type=int, default=3000) +parser.add_argument("--fl_name", type=str, default="Lenet") +parser.add_argument("--fl_iteration_num", type=int, default=25) +parser.add_argument("--client_epoch_num", type=int, default=20) +parser.add_argument("--client_batch_size", type=int, default=32) +parser.add_argument("--client_learning_rate", type=float, default=0.1) +parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) +parser.add_argument("--scheduler_manage_port", type=int, default=11202) +parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +# parameters for encrypt_type='DP_ENCRYPT' +parser.add_argument("--dp_eps", type=float, default=50.0) +parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num +parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) + +args, _ = parser.parse_known_args() +device_target = args.device_target +server_mode = args.server_mode +ms_role = args.ms_role +worker_num = args.worker_num +server_num = args.server_num +scheduler_ip = args.scheduler_ip +scheduler_port = args.scheduler_port +fl_server_port = args.fl_server_port +start_fl_job_threshold = args.start_fl_job_threshold +start_fl_job_time_window = args.start_fl_job_time_window +update_model_ratio = args.update_model_ratio +update_model_time_window = args.update_model_time_window +fl_name = args.fl_name +fl_iteration_num = args.fl_iteration_num +client_epoch_num = args.client_epoch_num +client_batch_size = args.client_batch_size +client_learning_rate = args.client_learning_rate +worker_step_num_per_iteration = args.worker_step_num_per_iteration +scheduler_manage_port = args.scheduler_manage_port +config_file_path = args.config_file_path +encrypt_type = args.encrypt_type +share_secrets_ratio = args.share_secrets_ratio +cipher_time_window = args.cipher_time_window +reconstruct_secrets_threshold = args.reconstruct_secrets_threshold +dp_eps = args.dp_eps +dp_delta = args.dp_delta +dp_norm_clip = args.dp_norm_clip + +ctx = { + "enable_fl": True, + "server_mode": server_mode, + "ms_role": ms_role, + "worker_num": worker_num, + "server_num": server_num, + "scheduler_ip": scheduler_ip, + "scheduler_port": scheduler_port, + "fl_server_port": fl_server_port, + "start_fl_job_threshold": start_fl_job_threshold, + "start_fl_job_time_window": start_fl_job_time_window, + "update_model_ratio": update_model_ratio, + "update_model_time_window": update_model_time_window, + "fl_name": fl_name, + "fl_iteration_num": fl_iteration_num, + "client_epoch_num": client_epoch_num, + "client_batch_size": client_batch_size, + "client_learning_rate": client_learning_rate, + "worker_step_num_per_iteration": worker_step_num_per_iteration, + "scheduler_manage_port": scheduler_manage_port, + "config_file_path": config_file_path, + "share_secrets_ratio": share_secrets_ratio, + "cipher_time_window": cipher_time_window, + "reconstruct_secrets_threshold": reconstruct_secrets_threshold, + "dp_eps": dp_eps, + "dp_delta": dp_delta, + "dp_norm_clip": dp_norm_clip, + "encrypt_type": encrypt_type +} + +context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) +context.set_fl_context(**ctx) + +if __name__ == "__main__": + epoch = 50000 + np.random.seed(0) + network = LeNet5(62) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + net_with_criterion = WithLossCell(network, criterion) + train_network = TrainOneStepCellForFLWorker(net_with_criterion, net_opt) + train_network.set_train() + losses = [] + + for _ in range(epoch): + data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32)) + label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32)) + loss = train_network(data, label).asnumpy() + losses.append(loss) + print(losses)