forked from mindspore-Ecosystem/mindspore

Fix cross silo running issue.

commit 7fa258f178
parent ecbf018e1c
@@ -27,8 +27,8 @@ bool IterationMetrics::Initialize() {
   config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
   MS_EXCEPTION_IF_NULL(config_);
   if (!config_->Initialize()) {
-    MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
-                      << " may be invalid or not exist.";
+    MS_LOG(WARNING) << "Initializing for metrics failed. Config file path " << config_file_path_
+                    << " may be invalid or not exist.";
     return false;
   }

@@ -62,11 +62,13 @@ bool IterationMetrics::Initialize() {
   }

+  metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
+  metrics_file_.close();
   }
   return true;
 }

 bool IterationMetrics::Summarize() {
   metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
   if (!metrics_file_.is_open()) {
     MS_LOG(ERROR) << "The metrics file is not opened.";
     return false;
@@ -83,6 +85,7 @@ bool IterationMetrics::Summarize() {
   js_[kIterExecutionTime] = iteration_time_cost_;
   metrics_file_ << js_ << "\n";
   (void)metrics_file_.flush();
+  metrics_file_.close();
   return true;
 }

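The three hunks above change IterationMetrics from holding the metrics file open for the life of the process to an open-append-close cycle: Initialize() now only creates the file, and Summarize() reopens it, writes one JSON record, and closes it again. A minimal sketch of the same pattern, in Python for brevity (the key name is illustrative; the C++ uses the kIterExecutionTime constant):

import json

class IterationMetricsSketch:
    """Open-append-close per record, mirroring the C++ change above."""

    def __init__(self, metrics_file_path):
        self.metrics_file_path = metrics_file_path
        # Initialize(): create the file up front, then release the handle.
        open(self.metrics_file_path, "a", encoding="utf-8").close()

    def summarize(self, iteration_time_cost):
        # Summarize(): reopen per call so no handle stays open between iterations.
        try:
            with open(self.metrics_file_path, "a", encoding="utf-8") as f:
                f.write(json.dumps({"iterExecutionTime": iteration_time_cost}) + "\n")
        except OSError:
            return False  # corresponds to the MS_LOG(ERROR) early return
        return True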
@@ -80,13 +80,17 @@ void FLWorker::Run() {
 }

 void FLWorker::Finalize() {
-  MS_EXCEPTION_IF_NULL(worker_node_);
-  if (!worker_node_->Finish()) {
-    MS_LOG(ERROR) << "Worker node finishing failed.";
+  if (worker_node_ == nullptr) {
+    MS_LOG(INFO) << "The worker is not initialized yet.";
+    return;
+  }

+  // In some cases, worker calls the Finish function while other nodes don't. So timeout is acceptable.
+  if (!worker_node_->Finish()) {
+    MS_LOG(WARNING) << "Finishing worker node timeout.";
   }
   if (!worker_node_->Stop()) {
-    MS_LOG(ERROR) << "Worker node stopping failed.";
+    MS_LOG(ERROR) << "Stopping worker node failed.";
     return;
   }
 }

@@ -76,9 +76,10 @@ bool TcpCommunicator::Start() {

 bool TcpCommunicator::Stop() {
   MS_EXCEPTION_IF_NULL(abstrace_node_);
+
+  // In some cases, server calls the Finish function while other nodes don't. So timeout is acceptable.
   if (!abstrace_node_->Finish()) {
-    MS_LOG(ERROR) << "Finishing server node failed.";
-    return false;
+    MS_LOG(WARNING) << "Finishing server node timeout.";
   }
   if (!abstrace_node_->Stop()) {
     MS_LOG(ERROR) << "Stopping server node failed.";

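The FLWorker::Finalize and TcpCommunicator::Stop hunks above apply the same shutdown pattern: return quietly if the node was never created, tolerate a Finish() timeout because peer nodes may never call Finish, and treat only a failed Stop() as a real error. A sketch of that pattern, assuming a hypothetical node object with finish()/stop() methods (not a MindSpore API):

import logging

def shutdown_node(node):
    """Best-effort shutdown, mirroring the two C++ hunks above."""
    if node is None:
        logging.info("The node is not initialized yet.")
        return
    # Peers may never call Finish, so treat a timeout as acceptable.
    if not node.finish():
        logging.warning("Finishing node timeout.")
    # A failed Stop, by contrast, is a genuine error.
    if not node.stop():
        logging.error("Stopping node failed.")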
@@ -32,18 +32,10 @@ parser.add_argument("--fl_name", type=str, default="Lenet")
 parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
 parser.add_argument("--client_batch_size", type=int, default=32)
-parser.add_argument("--client_learning_rate", type=float, default=0.1)
+parser.add_argument("--client_learning_rate", type=float, default=0.01)
 parser.add_argument("--local_server_num", type=int, default=-1)
 parser.add_argument("--config_file_path", type=str, default="")
 parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
-# parameters for encrypt_type='DP_ENCRYPT'
-parser.add_argument("--dp_eps", type=float, default=50.0)
-parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
-parser.add_argument("--dp_norm_clip", type=float, default=1.0)
-# parameters for encrypt_type='PW_ENCRYPT'
-parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
-parser.add_argument("--cipher_time_window", type=int, default=300000)
-parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
 parser.add_argument("--dataset_path", type=str, default="")

 args, _ = parser.parse_known_args()
@@ -66,12 +58,6 @@ client_learning_rate = args.client_learning_rate
 local_server_num = args.local_server_num
 config_file_path = args.config_file_path
 encrypt_type = args.encrypt_type
-share_secrets_ratio = args.share_secrets_ratio
-cipher_time_window = args.cipher_time_window
-reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
-dp_eps = args.dp_eps
-dp_delta = args.dp_delta
-dp_norm_clip = args.dp_norm_clip
 dataset_path = args.dataset_path

 if local_server_num == -1:
@@ -104,12 +90,6 @@ for i in range(local_server_num):
     cmd_server += " --client_batch_size=" + str(client_batch_size)
     cmd_server += " --client_learning_rate=" + str(client_learning_rate)
     cmd_server += " --encrypt_type=" + str(encrypt_type)
-    cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
-    cmd_server += " --cipher_time_window=" + str(cipher_time_window)
-    cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
-    cmd_server += " --dp_eps=" + str(dp_eps)
-    cmd_server += " --dp_delta=" + str(dp_delta)
-    cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
     cmd_server += " --dataset_path=" + str(dataset_path)
     cmd_server += " --user_id=" + str(0)
     cmd_server += " > server.log 2>&1 &"

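The launcher above only assembles cmd_server; the trailing "> server.log 2>&1 &" implies the string is ultimately run by a shell. The execution step sits outside these hunks; a plausible sketch, assuming the command is handed to bash via subprocess:

import subprocess

def launch(cmd):
    # The redirection and trailing '&' are shell syntax, so the assembled
    # command line must be interpreted by a shell rather than exec'd directly.
    subprocess.call(["/bin/bash", "-c", cmd])

# e.g. launch(cmd_server)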
@@ -25,6 +25,8 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
 parser.add_argument("--scheduler_port", type=int, default=8113)
+parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
 parser.add_argument("--client_batch_size", type=int, default=32)
+parser.add_argument("--client_learning_rate", type=float, default=0.01)
 parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
 parser.add_argument("--local_worker_num", type=int, default=-1)
 parser.add_argument("--config_file_path", type=str, default="")
@@ -39,6 +41,8 @@ scheduler_ip = args.scheduler_ip
 scheduler_port = args.scheduler_port
+fl_iteration_num = args.fl_iteration_num
 client_epoch_num = args.client_epoch_num
 client_batch_size = args.client_batch_size
+client_learning_rate = args.client_learning_rate
 worker_step_num_per_iteration = args.worker_step_num_per_iteration
 local_worker_num = args.local_worker_num
 config_file_path = args.config_file_path
@@ -65,6 +69,8 @@ for i in range(local_worker_num):
     cmd_worker += " --config_file_path=" + str(config_file_path)
+    cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
     cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
     cmd_worker += " --client_batch_size=" + str(client_batch_size)
+    cmd_worker += " --client_learning_rate=" + str(client_learning_rate)
     cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
     cmd_worker += " --dataset_path=" + str(dataset_path)
     cmd_worker += " --user_id=" + str(i)

@@ -47,25 +47,22 @@ parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
 parser.add_argument("--update_model_ratio", type=float, default=1.0)
 parser.add_argument("--update_model_time_window", type=int, default=3000)
 parser.add_argument("--fl_name", type=str, default="Lenet")
+# fl_iteration_num is also used as the global epoch number for Worker.
 parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
+# client_batch_size is also used as the batch size of each mini-batch for Worker.
 parser.add_argument("--client_batch_size", type=int, default=32)
-parser.add_argument("--client_learning_rate", type=float, default=0.1)
+# client_learning_rate is also used as the learning rate for Worker.
+parser.add_argument("--client_learning_rate", type=float, default=0.01)
 parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
 parser.add_argument("--scheduler_manage_port", type=int, default=11202)
 parser.add_argument("--config_file_path", type=str, default="")
 parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
-parser.add_argument("--dp_eps", type=float, default=50.0)
-parser.add_argument("--dp_delta", type=float, default=0.01)
-parser.add_argument("--dp_norm_clip", type=float, default=1.0)
-parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
-parser.add_argument("--cipher_time_window", type=int, default=300000)
-parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
 parser.add_argument("--dataset_path", type=str, default="")
 # The user_id is used to set each worker's dataset path.
 parser.add_argument("--user_id", type=str, default="0")

 parser.add_argument('--img_size', type=int, default=(32, 32, 1), help='the image size of (h,w,c)')
 parser.add_argument('--batch_size', type=float, default=32, help='batch size')
 parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when create the dataLoader')

 args, _ = parser.parse_known_args()
@@ -90,12 +87,6 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
 scheduler_manage_port = args.scheduler_manage_port
 config_file_path = args.config_file_path
 encrypt_type = args.encrypt_type
-share_secrets_ratio = args.share_secrets_ratio
-cipher_time_window = args.cipher_time_window
-reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
-dp_eps = args.dp_eps
-dp_delta = args.dp_delta
-dp_norm_clip = args.dp_norm_clip
 dataset_path = args.dataset_path
 user_id = args.user_id

@@ -120,12 +111,6 @@ ctx = {
     "worker_step_num_per_iteration": worker_step_num_per_iteration,
     "scheduler_manage_port": scheduler_manage_port,
     "config_file_path": config_file_path,
-    "share_secrets_ratio": share_secrets_ratio,
-    "cipher_time_window": cipher_time_window,
-    "reconstruct_secrets_threshold": reconstruct_secrets_threshold,
-    "dp_eps": dp_eps,
-    "dp_delta": dp_delta,
-    "dp_norm_clip": dp_norm_clip,
     "encrypt_type": encrypt_type
 }

@@ -314,13 +299,13 @@ class UpdateAndGetModel(nn.Cell):


 def train():
-    epoch = client_epoch_num
+    epoch = fl_iteration_num
     network = LeNet5(62, 3)

     # define the loss function
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     # define the optimizer
-    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+    net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})

     ds.config.set_seed(1)
@@ -329,7 +314,7 @@ def train():
     train_path = os.path.join(data_root_path, user, "train")
     test_path = os.path.join(data_root_path, user, "test")

-    dataset = create_dataset_from_folder(train_path, args.img_size, args.batch_size, args.repeat_size)
+    dataset = create_dataset_from_folder(train_path, args.img_size, args.client_batch_size, args.repeat_size)
     print("size is ", dataset.get_dataset_size(), flush=True)
     num_batches = dataset.get_dataset_size()

@@ -341,7 +326,7 @@ def train():

     for iter_num in range(fl_iteration_num):
         if context.get_fl_context("ms_role") == "MS_WORKER":
-            start_fl_job = StartFLJob(dataset.get_dataset_size() * args.batch_size)
+            start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
             start_fl_job()

         for _ in range(epoch):
@@ -356,8 +341,8 @@ def train():
         ckpt_name = os.path.join(ckpt_path, ckpt_name)
         save_checkpoint(network, ckpt_name)

-        train_acc, _ = evalute_process(model, train_path, args.img_size, args.batch_size)
-        test_acc, _ = evalute_process(model, test_path, args.img_size, args.batch_size)
+        train_acc, _ = evalute_process(model, train_path, args.img_size, args.client_batch_size)
+        test_acc, _ = evalute_process(model, test_path, args.img_size, args.client_batch_size)
         loss_list = loss_cb.get_loss()
         loss = sum(loss_list) / len(loss_list)
         print('local epoch: {}, loss: {}, trian acc: {}, test acc: {}'.format(iter_num, loss, train_acc, test_acc),

@@ -32,18 +32,10 @@ parser.add_argument("--fl_name", type=str, default="Lenet")
 parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
 parser.add_argument("--client_batch_size", type=int, default=32)
-parser.add_argument("--client_learning_rate", type=float, default=0.1)
+parser.add_argument("--client_learning_rate", type=float, default=0.01)
 parser.add_argument("--local_server_num", type=int, default=-1)
 parser.add_argument("--config_file_path", type=str, default="")
 parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
-# parameters for encrypt_type='DP_ENCRYPT'
-parser.add_argument("--dp_eps", type=float, default=50.0)
-parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
-parser.add_argument("--dp_norm_clip", type=float, default=1.0)
-# parameters for encrypt_type='PW_ENCRYPT'
-parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
-parser.add_argument("--cipher_time_window", type=int, default=300000)
-parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)

 args, _ = parser.parse_known_args()
 device_target = args.device_target
@@ -65,12 +57,6 @@ client_learning_rate = args.client_learning_rate
 local_server_num = args.local_server_num
 config_file_path = args.config_file_path
 encrypt_type = args.encrypt_type
-share_secrets_ratio = args.share_secrets_ratio
-cipher_time_window = args.cipher_time_window
-reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
-dp_eps = args.dp_eps
-dp_delta = args.dp_delta
-dp_norm_clip = args.dp_norm_clip

 if local_server_num == -1:
     local_server_num = server_num
@@ -102,12 +88,6 @@ for i in range(local_server_num):
     cmd_server += " --client_batch_size=" + str(client_batch_size)
     cmd_server += " --client_learning_rate=" + str(client_learning_rate)
     cmd_server += " --encrypt_type=" + str(encrypt_type)
-    cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
-    cmd_server += " --cipher_time_window=" + str(cipher_time_window)
-    cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
-    cmd_server += " --dp_eps=" + str(dp_eps)
-    cmd_server += " --dp_delta=" + str(dp_delta)
-    cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
     cmd_server += " > server.log 2>&1 &"

     import time

@@ -25,6 +25,8 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
 parser.add_argument("--scheduler_port", type=int, default=8113)
+parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
 parser.add_argument("--client_batch_size", type=int, default=32)
+parser.add_argument("--client_learning_rate", type=float, default=0.01)
 parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
 parser.add_argument("--local_worker_num", type=int, default=-1)
 parser.add_argument("--config_file_path", type=str, default="")
@@ -38,6 +40,8 @@ scheduler_ip = args.scheduler_ip
 scheduler_port = args.scheduler_port
+fl_iteration_num = args.fl_iteration_num
 client_epoch_num = args.client_epoch_num
 client_batch_size = args.client_batch_size
+client_learning_rate = args.client_learning_rate
 worker_step_num_per_iteration = args.worker_step_num_per_iteration
 local_worker_num = args.local_worker_num
 config_file_path = args.config_file_path
@@ -64,6 +68,8 @@ for i in range(local_worker_num):
     cmd_worker += " --config_file_path=" + str(config_file_path)
+    cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
     cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
     cmd_worker += " --client_batch_size=" + str(client_batch_size)
+    cmd_worker += " --client_learning_rate=" + str(client_learning_rate)
     cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
     cmd_worker += " > worker.log 2>&1 &"

@@ -14,6 +14,7 @@
 # ============================================================================

 import mindspore.nn as nn
+from mindspore.ops import operations as P
 from mindspore.common.initializer import TruncatedNormal

 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -70,3 +71,25 @@ class LeNet5(nn.Cell):
         x = self.relu(x)
         x = self.fc3(x)
         return x
+
+
+class StartFLJob(nn.Cell):
+    def __init__(self, data_size):
+        super(StartFLJob, self).__init__()
+        self.start_fl_job = P.StartFLJob(data_size)
+
+    def construct(self):
+        return self.start_fl_job()
+
+
+class UpdateAndGetModel(nn.Cell):
+    def __init__(self, weights):
+        super(UpdateAndGetModel, self).__init__()
+        self.update_model = P.UpdateModel()
+        self.get_model = P.GetModel()
+        self.weights = weights
+
+    def construct(self):
+        self.update_model(self.weights)
+        get_model = self.get_model(self.weights)
+        return get_model

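StartFLJob and UpdateAndGetModel wrap the federated-learning primitives P.StartFLJob, P.UpdateModel, and P.GetModel in nn.Cell subclasses so they can be invoked in graph mode. A minimal sketch of how a worker drives them for one iteration, matching their use in the training scripts in this commit (train_fn, weights, and data_size are placeholders; a running FL server and scheduler are assumed):

from src.model import StartFLJob, UpdateAndGetModel

def run_one_fl_iteration(train_fn, weights, data_size):
    # Register this worker's sample count with the server.
    StartFLJob(data_size)()
    # Run the local training epochs.
    train_fn()
    # Push local weights and pull back the aggregated model.
    UpdateAndGetModel(weights)()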
@@ -19,9 +19,8 @@ import numpy as np
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
-from mindspore.nn import WithLossCell
-from src.cell_wrapper import TrainOneStepCellForFLWorker
-from src.model import LeNet5
+from mindspore.nn import WithLossCell, TrainOneStepCell
+from src.model import LeNet5, StartFLJob, UpdateAndGetModel

 parser = argparse.ArgumentParser(description="test_cross_silo_lenet")
 parser.add_argument("--device_target", type=str, default="GPU")
@@ -37,22 +36,17 @@ parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
 parser.add_argument("--update_model_ratio", type=float, default=1.0)
 parser.add_argument("--update_model_time_window", type=int, default=3000)
 parser.add_argument("--fl_name", type=str, default="Lenet")
+# fl_iteration_num is also used as the global epoch number for Worker.
 parser.add_argument("--fl_iteration_num", type=int, default=25)
 parser.add_argument("--client_epoch_num", type=int, default=20)
+# client_batch_size is also used as the batch size of each mini-batch for Worker.
 parser.add_argument("--client_batch_size", type=int, default=32)
-parser.add_argument("--client_learning_rate", type=float, default=0.1)
+# client_learning_rate is also used as the learning rate for Worker.
+parser.add_argument("--client_learning_rate", type=float, default=0.01)
 parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
 parser.add_argument("--scheduler_manage_port", type=int, default=11202)
 parser.add_argument("--config_file_path", type=str, default="")
 parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
-# parameters for encrypt_type='DP_ENCRYPT'
-parser.add_argument("--dp_eps", type=float, default=50.0)
-parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
-parser.add_argument("--dp_norm_clip", type=float, default=1.0)
-# parameters for encrypt_type='PW_ENCRYPT'
-parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
-parser.add_argument("--cipher_time_window", type=int, default=300000)
-parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)

 args, _ = parser.parse_known_args()
 device_target = args.device_target
@@ -76,12 +70,6 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
 scheduler_manage_port = args.scheduler_manage_port
 config_file_path = args.config_file_path
 encrypt_type = args.encrypt_type
-share_secrets_ratio = args.share_secrets_ratio
-cipher_time_window = args.cipher_time_window
-reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
-dp_eps = args.dp_eps
-dp_delta = args.dp_delta
-dp_norm_clip = args.dp_norm_clip

 ctx = {
     "enable_fl": True,
@@ -104,12 +92,6 @@ ctx = {
     "worker_step_num_per_iteration": worker_step_num_per_iteration,
     "scheduler_manage_port": scheduler_manage_port,
     "config_file_path": config_file_path,
-    "share_secrets_ratio": share_secrets_ratio,
-    "cipher_time_window": cipher_time_window,
-    "reconstruct_secrets_threshold": reconstruct_secrets_threshold,
-    "dp_eps": dp_eps,
-    "dp_delta": dp_delta,
-    "dp_norm_clip": dp_norm_clip,
     "encrypt_type": encrypt_type
 }

@@ -117,19 +99,27 @@ context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_g
 context.set_fl_context(**ctx)

 if __name__ == "__main__":
-    epoch = 50000
+    epoch = fl_iteration_num
     np.random.seed(0)
     network = LeNet5(62)
     criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+    net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
     net_with_criterion = WithLossCell(network, criterion)
-    train_network = TrainOneStepCellForFLWorker(net_with_criterion, net_opt)
+    train_network = TrainOneStepCell(net_with_criterion, net_opt)
     train_network.set_train()
     losses = []

     for _ in range(epoch):
-        data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
-        label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
+        if context.get_fl_context("ms_role") == "MS_WORKER":
+            start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
+            start_fl_job()
+
+        data = Tensor(np.random.rand(client_batch_size, 3, 32, 32).astype(np.float32))
+        label = Tensor(np.random.randint(0, 61, (client_batch_size)).astype(np.int32))
         loss = train_network(data, label).asnumpy()
         losses.append(loss)
+
+        if context.get_fl_context("ms_role") == "MS_WORKER":
+            update_and_get_model = UpdateAndGetModel(net_opt.parameters)
+            update_and_get_model()
     print(losses)