forked from mindspore-Ecosystem/mindspore
Add test case for cross-silo lenet
This commit is contained in:
parent
349d3f85cb
commit
fc8a16a149
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"recovery": {
|
||||
"storge_type": 1,
|
||||
"storage_file_path": "recovery.json"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser(description="Finish test_cross_silo_lenet.py case")
|
||||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
scheduler_port = args.scheduler_port
|
||||
|
||||
cmd = "pid=`ps -ef|grep \"scheduler_port=" + str(scheduler_port) + "\" "
|
||||
cmd += " | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'` && "
|
||||
cmd += "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"
|
||||
|
||||
subprocess.call(['bash', '-c', cmd])
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run test_cross_silo_lenet.py case")
|
||||
parser.add_argument("--device_target", type=str, default="CPU")
|
||||
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
|
||||
parser.add_argument("--worker_num", type=int, default=1)
|
||||
parser.add_argument("--server_num", type=int, default=2)
|
||||
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
server_mode = args.server_mode
|
||||
worker_num = args.worker_num
|
||||
server_num = args.server_num
|
||||
scheduler_ip = args.scheduler_ip
|
||||
scheduler_port = args.scheduler_port
|
||||
scheduler_manage_port = args.scheduler_manage_port
|
||||
config_file_path = args.config_file_path
|
||||
|
||||
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
|
||||
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
|
||||
cmd_sched += "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
|
||||
cmd_sched += "python ${self_path}/../test_cross_silo_lenet.py"
|
||||
cmd_sched += " --device_target=" + device_target
|
||||
cmd_sched += " --server_mode=" + server_mode
|
||||
cmd_sched += " --ms_role=MS_SCHED"
|
||||
cmd_sched += " --worker_num=" + str(worker_num)
|
||||
cmd_sched += " --server_num=" + str(server_num)
|
||||
cmd_sched += " --config_file_path=" + str(config_file_path)
|
||||
cmd_sched += " --scheduler_ip=" + scheduler_ip
|
||||
cmd_sched += " --scheduler_port=" + str(scheduler_port)
|
||||
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
|
||||
cmd_sched += " > scheduler.log 2>&1 &"
|
||||
|
||||
subprocess.call(['bash', '-c', cmd_sched])
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run test_cross_silo_lenet.py case")
|
||||
parser.add_argument("--device_target", type=str, default="CPU")
|
||||
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
|
||||
parser.add_argument("--worker_num", type=int, default=1)
|
||||
parser.add_argument("--server_num", type=int, default=2)
|
||||
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
parser.add_argument("--fl_server_port", type=int, default=6666)
|
||||
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
||||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--local_server_num", type=int, default=-1)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
# parameters for encrypt_type='DP_ENCRYPT'
|
||||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
# parameters for encrypt_type='PW_ENCRYPT'
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
server_mode = args.server_mode
|
||||
worker_num = args.worker_num
|
||||
server_num = args.server_num
|
||||
scheduler_ip = args.scheduler_ip
|
||||
scheduler_port = args.scheduler_port
|
||||
fl_server_port = args.fl_server_port
|
||||
start_fl_job_threshold = args.start_fl_job_threshold
|
||||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
local_server_num = args.local_server_num
|
||||
config_file_path = args.config_file_path
|
||||
encrypt_type = args.encrypt_type
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
||||
dp_eps = args.dp_eps
|
||||
dp_delta = args.dp_delta
|
||||
dp_norm_clip = args.dp_norm_clip
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
||||
assert local_server_num <= server_num, "The local server number should not be bigger than total server number."
|
||||
|
||||
for i in range(local_server_num):
|
||||
cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
|
||||
cmd_server += "rm -rf ${execute_path}/server_" + str(i) + "/ &&"
|
||||
cmd_server += "mkdir ${execute_path}/server_" + str(i) + "/ &&"
|
||||
cmd_server += "cd ${execute_path}/server_" + str(i) + "/ || exit && export GLOG_v=1 &&"
|
||||
cmd_server += "python ${self_path}/../test_cross_silo_lenet.py"
|
||||
cmd_server += " --device_target=" + device_target
|
||||
cmd_server += " --server_mode=" + server_mode
|
||||
cmd_server += " --ms_role=MS_SERVER"
|
||||
cmd_server += " --worker_num=" + str(worker_num)
|
||||
cmd_server += " --server_num=" + str(server_num)
|
||||
cmd_server += " --scheduler_ip=" + scheduler_ip
|
||||
cmd_server += " --scheduler_port=" + str(scheduler_port)
|
||||
cmd_server += " --fl_server_port=" + str(fl_server_port + i)
|
||||
cmd_server += " --start_fl_job_threshold=" + str(start_fl_job_threshold)
|
||||
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
|
||||
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
|
||||
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
|
||||
cmd_server += " --fl_name=" + fl_name
|
||||
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
|
||||
cmd_server += " --config_file_path=" + str(config_file_path)
|
||||
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
|
||||
cmd_server += " --client_batch_size=" + str(client_batch_size)
|
||||
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
|
||||
cmd_server += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
|
||||
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
|
||||
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
|
||||
cmd_server += " --dp_eps=" + str(dp_eps)
|
||||
cmd_server += " --dp_delta=" + str(dp_delta)
|
||||
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
time.sleep(0.3)
|
||||
subprocess.call(['bash', '-c', cmd_server])
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run test_cross_silo_lenet.py case")
|
||||
parser.add_argument("--device_target", type=str, default="GPU")
|
||||
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
|
||||
parser.add_argument("--worker_num", type=int, default=1)
|
||||
parser.add_argument("--server_num", type=int, default=2)
|
||||
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
|
||||
parser.add_argument("--local_worker_num", type=int, default=-1)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
server_mode = args.server_mode
|
||||
worker_num = args.worker_num
|
||||
server_num = args.server_num
|
||||
scheduler_ip = args.scheduler_ip
|
||||
scheduler_port = args.scheduler_port
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
worker_step_num_per_iteration = args.worker_step_num_per_iteration
|
||||
local_worker_num = args.local_worker_num
|
||||
config_file_path = args.config_file_path
|
||||
|
||||
if local_worker_num == -1:
|
||||
local_worker_num = worker_num
|
||||
|
||||
assert local_worker_num <= worker_num, "The local worker number should not be bigger than total worker number."
|
||||
|
||||
for i in range(local_worker_num):
|
||||
cmd_worker = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
|
||||
cmd_worker += "rm -rf ${execute_path}/worker_" + str(i) + "/ &&"
|
||||
cmd_worker += "mkdir ${execute_path}/worker_" + str(i) + "/ &&"
|
||||
cmd_worker += "cd ${execute_path}/worker_" + str(i) + "/ || exit && export GLOG_v=1 && "
|
||||
cmd_worker += "export CUDA_VISIBLE_DEVICES=" + str(i) +" && "
|
||||
cmd_worker += "python ${self_path}/../test_cross_silo_lenet.py"
|
||||
cmd_worker += " --device_target=" + device_target
|
||||
cmd_worker += " --server_mode=" + server_mode
|
||||
cmd_worker += " --ms_role=MS_WORKER"
|
||||
cmd_worker += " --worker_num=" + str(worker_num)
|
||||
cmd_worker += " --server_num=" + str(server_num)
|
||||
cmd_worker += " --scheduler_ip=" + scheduler_ip
|
||||
cmd_worker += " --scheduler_port=" + str(scheduler_port)
|
||||
cmd_worker += " --config_file_path=" + str(config_file_path)
|
||||
cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
|
||||
cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
|
||||
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
|
||||
cmd_worker += " > worker.log 2>&1 &"
|
||||
|
||||
subprocess.call(['bash', '-c', cmd_worker])
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
|
||||
class TrainOneStepCellForFLWorker(TrainOneStepCell):
|
||||
"""
|
||||
Wraps the network with federated learning operators in worker.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0, batch_size=32):
|
||||
super(TrainOneStepCellForFLWorker, self).__init__(network, optimizer, sens)
|
||||
self.batch_size = batch_size
|
||||
self.start_fl_job = P.StartFLJob(batch_size)
|
||||
self.update_model = P.UpdateModel()
|
||||
self.get_model = P.GetModel()
|
||||
self.depend = P.Depend()
|
||||
|
||||
def construct(self, *inputs):
|
||||
start_fl_job = self.start_fl_job()
|
||||
inputs = self.depend(inputs, start_fl_job)
|
||||
|
||||
loss = self.network(*inputs)
|
||||
sens = F.fill(loss.dtype, loss.shape, self.sens)
|
||||
grads = self.grad(self.network, self.weights)(*inputs, sens)
|
||||
grads = self.grad_reducer(grads)
|
||||
loss = self.depend(loss, self.optimizer(grads))
|
||||
|
||||
self.update_model(self.weights)
|
||||
get_model = self.get_model(self.weights)
|
||||
|
||||
loss = self.depend(loss, get_model)
|
||||
return loss
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
"""weight initial for conv layer"""
|
||||
weight = weight_variable()
|
||||
return nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
weight_init=weight,
|
||||
has_bias=False,
|
||||
pad_mode="valid",
|
||||
)
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
"""weight initial for fc layer"""
|
||||
weight = weight_variable()
|
||||
bias = weight_variable()
|
||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""weight initial"""
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
def __init__(self, num_class=10, channel=3):
|
||||
super(LeNet5, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.conv1 = conv(channel, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, self.num_class)
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
|
@ -0,0 +1,135 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import WithLossCell
|
||||
from src.cell_wrapper import TrainOneStepCellForFLWorker
|
||||
from src.model import LeNet5
|
||||
|
||||
parser = argparse.ArgumentParser(description="test_cross_silo_lenet")
|
||||
parser.add_argument("--device_target", type=str, default="GPU")
|
||||
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
|
||||
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
|
||||
parser.add_argument("--worker_num", type=int, default=1)
|
||||
parser.add_argument("--server_num", type=int, default=1)
|
||||
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--scheduler_port", type=int, default=8113)
|
||||
parser.add_argument("--fl_server_port", type=int, default=6666)
|
||||
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
||||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
parser.add_argument("--client_batch_size", type=int, default=32)
|
||||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
|
||||
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
# parameters for encrypt_type='DP_ENCRYPT'
|
||||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
# parameters for encrypt_type='PW_ENCRYPT'
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
server_mode = args.server_mode
|
||||
ms_role = args.ms_role
|
||||
worker_num = args.worker_num
|
||||
server_num = args.server_num
|
||||
scheduler_ip = args.scheduler_ip
|
||||
scheduler_port = args.scheduler_port
|
||||
fl_server_port = args.fl_server_port
|
||||
start_fl_job_threshold = args.start_fl_job_threshold
|
||||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
client_batch_size = args.client_batch_size
|
||||
client_learning_rate = args.client_learning_rate
|
||||
worker_step_num_per_iteration = args.worker_step_num_per_iteration
|
||||
scheduler_manage_port = args.scheduler_manage_port
|
||||
config_file_path = args.config_file_path
|
||||
encrypt_type = args.encrypt_type
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
|
||||
dp_eps = args.dp_eps
|
||||
dp_delta = args.dp_delta
|
||||
dp_norm_clip = args.dp_norm_clip
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
"server_mode": server_mode,
|
||||
"ms_role": ms_role,
|
||||
"worker_num": worker_num,
|
||||
"server_num": server_num,
|
||||
"scheduler_ip": scheduler_ip,
|
||||
"scheduler_port": scheduler_port,
|
||||
"fl_server_port": fl_server_port,
|
||||
"start_fl_job_threshold": start_fl_job_threshold,
|
||||
"start_fl_job_time_window": start_fl_job_time_window,
|
||||
"update_model_ratio": update_model_ratio,
|
||||
"update_model_time_window": update_model_time_window,
|
||||
"fl_name": fl_name,
|
||||
"fl_iteration_num": fl_iteration_num,
|
||||
"client_epoch_num": client_epoch_num,
|
||||
"client_batch_size": client_batch_size,
|
||||
"client_learning_rate": client_learning_rate,
|
||||
"worker_step_num_per_iteration": worker_step_num_per_iteration,
|
||||
"scheduler_manage_port": scheduler_manage_port,
|
||||
"config_file_path": config_file_path,
|
||||
"share_secrets_ratio": share_secrets_ratio,
|
||||
"cipher_time_window": cipher_time_window,
|
||||
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
|
||||
"dp_eps": dp_eps,
|
||||
"dp_delta": dp_delta,
|
||||
"dp_norm_clip": dp_norm_clip,
|
||||
"encrypt_type": encrypt_type
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
|
||||
context.set_fl_context(**ctx)
|
||||
|
||||
if __name__ == "__main__":
|
||||
epoch = 50000
|
||||
np.random.seed(0)
|
||||
network = LeNet5(62)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
|
||||
net_with_criterion = WithLossCell(network, criterion)
|
||||
train_network = TrainOneStepCellForFLWorker(net_with_criterion, net_opt)
|
||||
train_network.set_train()
|
||||
losses = []
|
||||
|
||||
for _ in range(epoch):
|
||||
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
|
||||
label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
|
||||
loss = train_network(data, label).asnumpy()
|
||||
losses.append(loss)
|
||||
print(losses)
|
Loading…
Reference in New Issue