Fix cross-silo running issue.

ZPaC 2021-09-09 14:42:40 +08:00
parent ecbf018e1c
commit 7fa258f178
10 changed files with 83 additions and 105 deletions

View File

@@ -27,8 +27,8 @@ bool IterationMetrics::Initialize() {
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
<< " may be invalid or not exist.";
MS_LOG(WARNING) << "Initializing for metrics failed. Config file path " << config_file_path_
<< " may be invalid or not exist.";
return false;
}
@@ -62,11 +62,13 @@ bool IterationMetrics::Initialize() {
}
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
metrics_file_.close();
}
return true;
}
bool IterationMetrics::Summarize() {
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
if (!metrics_file_.is_open()) {
MS_LOG(ERROR) << "The metrics file is not opened.";
return false;
@@ -83,6 +85,7 @@ bool IterationMetrics::Summarize() {
js_[kIterExecutionTime] = iteration_time_cost_;
metrics_file_ << js_ << "\n";
(void)metrics_file_.flush();
metrics_file_.close();
return true;
}
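The net effect of the two hunks above: a bad config path now degrades to a warning instead of an exception, and the metrics file is no longer held open between iterations — Summarize() opens it, appends one record, flushes, and closes. A minimal Python sketch of that open-append-close pattern (the file path and metric key are illustrative, not from the commit):

```python
import json

def summarize(metrics_path, iteration_time_cost):
    # Open on every call so no long-lived handle survives across
    # iterations; append one JSON record, flush, and close.
    try:
        with open(metrics_path, "a") as metrics_file:
            record = {"iterationExecutionTime": iteration_time_cost}
            metrics_file.write(json.dumps(record) + "\n")
            metrics_file.flush()
    except OSError:
        print("ERROR: The metrics file is not opened.")
        return False
    return True
```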

View File

@@ -80,13 +80,17 @@ void FLWorker::Run() {
}
void FLWorker::Finalize() {
MS_EXCEPTION_IF_NULL(worker_node_);
if (!worker_node_->Finish()) {
MS_LOG(ERROR) << "Worker node finishing failed.";
if (worker_node_ == nullptr) {
MS_LOG(INFO) << "The worker is not initialized yet.";
return;
}
// In some cases, worker calls the Finish function while other nodes don't. So timeout is acceptable.
if (!worker_node_->Finish()) {
MS_LOG(WARNING) << "Finishing worker node timeout.";
}
if (!worker_node_->Stop()) {
MS_LOG(ERROR) << "Worker node stopping failed.";
MS_LOG(ERROR) << "Stopping worker node failed.";
return;
}
}
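A sketch of the shutdown ordering this hunk establishes, transliterated to Python for brevity (the node API is illustrative): an uninitialized worker is skipped quietly, a Finish timeout is tolerated with a warning, and only a failed Stop is treated as an error.

```python
def finalize(worker_node):
    # Not initialized yet: nothing to tear down.
    if worker_node is None:
        print("INFO: The worker is not initialized yet.")
        return
    # Other nodes may never call Finish, so a timeout is acceptable.
    if not worker_node.finish():
        print("WARNING: Finishing worker node timeout.")
    # Stopping must still succeed; failure here is a real error.
    if not worker_node.stop():
        print("ERROR: Stopping worker node failed.")
```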

View File

@@ -76,9 +76,10 @@ bool TcpCommunicator::Start() {
bool TcpCommunicator::Stop() {
MS_EXCEPTION_IF_NULL(abstrace_node_);
// In some cases, server calls the Finish function while other nodes don't. So timeout is acceptable.
if (!abstrace_node_->Finish()) {
MS_LOG(ERROR) << "Finishing server node failed.";
return false;
MS_LOG(WARNING) << "Finishing server node timeout.";
}
if (!abstrace_node_->Stop()) {
MS_LOG(ERROR) << "Stopping server node failed.";
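The server-side communicator gets the same relaxation: a Finish timeout no longer short-circuits Stop(), so this node is still stopped even when peers never called Finish. A hedged Python transliteration (node API illustrative):

```python
def stop(abstract_node):
    # Tolerate peers that never call Finish; warn and keep going so
    # this node's own Stop still runs.
    if not abstract_node.finish():
        print("WARNING: Finishing server node timeout.")
    if not abstract_node.stop():
        print("ERROR: Stopping server node failed.")
        return False
    return True
```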

View File

@@ -32,18 +32,10 @@ parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--client_learning_rate", type=float, default=0.01)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--dataset_path", type=str, default="")
args, _ = parser.parse_known_args()
@@ -66,12 +58,6 @@ client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
dataset_path = args.dataset_path
if local_server_num == -1:
@@ -104,12 +90,6 @@ for i in range(local_server_num):
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " --dataset_path=" + str(dataset_path)
cmd_server += " --user_id=" + str(0)
cmd_server += " > server.log 2>&1 &"
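After these hunks the femnist server launcher forwards only the arguments the cross-silo flow still needs; the DP_ENCRYPT and PW_ENCRYPT flags are gone from the command line. A condensed, runnable sketch of the launch pattern under that assumption (the script name and dataset path are placeholders):

```python
import subprocess

# Illustrative subset of the arguments the launcher still forwards;
# the encryption-specific flags are intentionally absent.
server_args = {
    "client_batch_size": 32,
    "client_learning_rate": 0.01,
    "encrypt_type": "NOT_ENCRYPT",
    "dataset_path": "/path/to/femnist",  # placeholder
    "user_id": 0,
}

cmd_server = "python test_cross_silo_femnist.py"  # placeholder script name
for key, value in server_args.items():
    cmd_server += " --" + key + "=" + str(value)
cmd_server += " > server.log 2>&1 &"
subprocess.call(cmd_server, shell=True)
```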

View File

@@ -25,6 +25,8 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.01)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
@@ -39,6 +41,8 @@ scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
@@ -65,6 +69,8 @@ for i in range(local_worker_num):
cmd_worker += " --config_file_path=" + str(config_file_path)
cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
cmd_worker += " --client_batch_size=" + str(client_batch_size)
cmd_worker += " --client_learning_rate=" + str(client_learning_rate)
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
cmd_worker += " --dataset_path=" + str(dataset_path)
cmd_worker += " --user_id=" + str(i)
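The worker launcher gains the same two knobs, so batch size and learning rate stay consistent between server and worker processes. A minimal sketch of the per-worker launch loop (script name is a placeholder; only the new arguments are shown):

```python
import subprocess

local_worker_num = 2
for i in range(local_worker_num):
    cmd_worker = "python test_cross_silo_femnist.py"  # placeholder script name
    cmd_worker += " --client_batch_size=" + str(32)
    cmd_worker += " --client_learning_rate=" + str(0.01)
    cmd_worker += " --user_id=" + str(i)
    cmd_worker += " > worker.log 2>&1 &"
    subprocess.call(cmd_worker, shell=True)
```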

View File

@@ -47,25 +47,22 @@ parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
# fl_iteration_num is also used as the global epoch number for Worker.
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
# client_batch_size is also used as the batch size of each mini-batch for Worker.
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
# client_learning_rate is also used as the learning rate for Worker.
parser.add_argument("--client_learning_rate", type=float, default=0.01)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--dataset_path", type=str, default="")
# The user_id is used to set each worker's dataset path.
parser.add_argument("--user_id", type=str, default="0")
parser.add_argument('--img_size', type=int, default=(32, 32, 1), help='the image size of (h,w,c)')
parser.add_argument('--batch_size', type=float, default=32, help='batch size')
parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when creating the dataLoader')
args, _ = parser.parse_known_args()
@@ -90,12 +87,6 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
dataset_path = args.dataset_path
user_id = args.user_id
@@ -120,12 +111,6 @@ ctx = {
"worker_step_num_per_iteration": worker_step_num_per_iteration,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path,
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
"dp_eps": dp_eps,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
}
@@ -314,13 +299,13 @@ class UpdateAndGetModel(nn.Cell):
def train():
epoch = client_epoch_num
epoch = fl_iteration_num
network = LeNet5(62, 3)
# define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})
ds.config.set_seed(1)
@@ -329,7 +314,7 @@ def train():
train_path = os.path.join(data_root_path, user, "train")
test_path = os.path.join(data_root_path, user, "test")
dataset = create_dataset_from_folder(train_path, args.img_size, args.batch_size, args.repeat_size)
dataset = create_dataset_from_folder(train_path, args.img_size, args.client_batch_size, args.repeat_size)
print("size is ", dataset.get_dataset_size(), flush=True)
num_batches = dataset.get_dataset_size()
@@ -341,7 +326,7 @@ def train():
for iter_num in range(fl_iteration_num):
if context.get_fl_context("ms_role") == "MS_WORKER":
start_fl_job = StartFLJob(dataset.get_dataset_size() * args.batch_size)
start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
start_fl_job()
for _ in range(epoch):
@@ -356,8 +341,8 @@ def train():
ckpt_name = os.path.join(ckpt_path, ckpt_name)
save_checkpoint(network, ckpt_name)
train_acc, _ = evalute_process(model, train_path, args.img_size, args.batch_size)
test_acc, _ = evalute_process(model, test_path, args.img_size, args.batch_size)
train_acc, _ = evalute_process(model, train_path, args.img_size, args.client_batch_size)
test_acc, _ = evalute_process(model, test_path, args.img_size, args.client_batch_size)
loss_list = loss_cb.get_loss()
loss = sum(loss_list) / len(loss_list)
print('local epoch: {}, loss: {}, train acc: {}, test acc: {}'.format(iter_num, loss, train_acc, test_acc),
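Taken together, the worker loop above follows a fixed per-iteration protocol: announce the local data volume with StartFLJob, train client_epoch_num local epochs, then exchange weights. A condensed sketch of one iteration, assuming the cells from src/model.py and an already-built MindSpore Model (callbacks and checkpointing omitted):

```python
from src.model import StartFLJob, UpdateAndGetModel

def run_fl_iteration(model, dataset, optimizer, batch_size, client_epoch_num):
    # 1. Report participation and the local data size to the server.
    StartFLJob(dataset.get_dataset_size() * batch_size)()
    # 2. Train locally for the configured number of epochs.
    for _ in range(client_epoch_num):
        model.train(1, dataset, dataset_sink_mode=False)
    # 3. Push the local update and pull back the aggregated weights.
    UpdateAndGetModel(optimizer.parameters)()
```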

View File

@@ -32,18 +32,10 @@ parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--client_learning_rate", type=float, default=0.01)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@@ -65,12 +57,6 @@ client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
if local_server_num == -1:
local_server_num = server_num
@@ -102,12 +88,6 @@ for i in range(local_server_num):
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " > server.log 2>&1 &"
import time

View File

@@ -25,6 +25,8 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.01)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
@@ -38,6 +40,8 @@ scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
@@ -64,6 +68,8 @@ for i in range(local_worker_num):
cmd_worker += " --config_file_path=" + str(config_file_path)
cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_worker += " --client_epoch_num=" + str(client_epoch_num)
cmd_worker += " --client_batch_size=" + str(client_batch_size)
cmd_worker += " --client_learning_rate=" + str(client_learning_rate)
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
cmd_worker += " > worker.log 2>&1 &"

View File

@@ -14,6 +14,7 @@
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -70,3 +71,25 @@ class LeNet5(nn.Cell):
x = self.relu(x)
x = self.fc3(x)
return x
class StartFLJob(nn.Cell):
def __init__(self, data_size):
super(StartFLJob, self).__init__()
self.start_fl_job = P.StartFLJob(data_size)
def construct(self):
return self.start_fl_job()
class UpdateAndGetModel(nn.Cell):
def __init__(self, weights):
super(UpdateAndGetModel, self).__init__()
self.update_model = P.UpdateModel()
self.get_model = P.GetModel()
self.weights = weights
def construct(self):
self.update_model(self.weights)
get_model = self.get_model(self.weights)
return get_model
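A short usage sketch for the two cells added above, mirroring how the updated test scripts drive them (running it for real requires an FL context with the MS_WORKER role, as set up in the scripts below):

```python
import mindspore.nn as nn
from src.model import LeNet5, StartFLJob, UpdateAndGetModel

network = LeNet5(62)
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)

# Report a data size of one batch to the server, then do one
# update/get round with the optimizer's parameter list.
StartFLJob(32)()
UpdateAndGetModel(net_opt.parameters)()
```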

View File

@@ -19,9 +19,8 @@ import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import WithLossCell
from src.cell_wrapper import TrainOneStepCellForFLWorker
from src.model import LeNet5
from mindspore.nn import WithLossCell, TrainOneStepCell
from src.model import LeNet5, StartFLJob, UpdateAndGetModel
parser = argparse.ArgumentParser(description="test_cross_silo_lenet")
parser.add_argument("--device_target", type=str, default="GPU")
@@ -37,22 +36,17 @@ parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
# fl_iteration_num is also used as the global epoch number for Worker.
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
# client_batch_size is also used as the batch size of each mini-batch for Worker.
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
# client_learning_rate is also used as the learning rate for Worker.
parser.add_argument("--client_learning_rate", type=float, default=0.01)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@@ -76,12 +70,6 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
ctx = {
"enable_fl": True,
@@ -104,12 +92,6 @@ ctx = {
"worker_step_num_per_iteration": worker_step_num_per_iteration,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path,
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshold": reconstruct_secrets_threshold,
"dp_eps": dp_eps,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
}
@@ -117,19 +99,27 @@ context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_g
context.set_fl_context(**ctx)
if __name__ == "__main__":
epoch = 50000
epoch = fl_iteration_num
np.random.seed(0)
network = LeNet5(62)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
net_opt = nn.Momentum(network.trainable_params(), client_learning_rate, 0.9)
net_with_criterion = WithLossCell(network, criterion)
train_network = TrainOneStepCellForFLWorker(net_with_criterion, net_opt)
train_network = TrainOneStepCell(net_with_criterion, net_opt)
train_network.set_train()
losses = []
for _ in range(epoch):
data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 61, (32)).astype(np.int32))
if context.get_fl_context("ms_role") == "MS_WORKER":
start_fl_job = StartFLJob(client_batch_size)  # no dataset in this synthetic test; report one batch as the data size
start_fl_job()
data = Tensor(np.random.rand(client_batch_size, 3, 32, 32).astype(np.float32))
label = Tensor(np.random.randint(0, 61, (client_batch_size)).astype(np.int32))
loss = train_network(data, label).asnumpy()
losses.append(loss)
if context.get_fl_context("ms_role") == "MS_WORKER":
update_and_get_model = UpdateAndGetModel(net_opt.parameters)
update_and_get_model()
print(losses)