!20812 Add code for cross-silo with real data femnist and coco

Merge pull request !20812 from ZPaC/cross-silo
This commit is contained in:
i-robot 2021-07-26 13:21:19 +00:00 committed by Gitee
commit 34e2581dfc
11 changed files with 762 additions and 5 deletions

View File

@ -5,6 +5,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/sgd_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc")

View File

@ -111,6 +111,7 @@ constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";
constexpr auto kStat = "stat";
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
// launched.
@ -155,11 +156,14 @@ const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
{"weight_decay", 7},
{"grad", 8}}},
{"outputs", {}}};
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx},
{kFusedSparseAdamName, kSparseAdamNameToIdx},
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
{kApplyAdamOpName, kAdamNameToIdx},
{"AdamWeightDecay", kAdamWeightDecayNameToIdx}};
// Parameter offsets used when the SGD optimizer kernel is launched. The inputs are indexed as weight(0),
// gradient(1), learning rate(2), accumulation(3), momentum(4) and stat(5); the outputs table is empty.
const OptimParamNameToIndex kSGDNameToIdx = {
{"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}},
{"outputs", {}}};
// Maps an optimizer operator name to the table describing its inputs/workspace/outputs parameter offsets.
// The SGD entry is keyed by kSGDName and points at kSGDNameToIdx.
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
{kApplyMomentumOpName, kMomentumNameToIdx}, {kFusedSparseAdamName, kSparseAdamNameToIdx},
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx},
{"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}};
constexpr uint32_t kLeaderServerRank = 0;
constexpr size_t kWorkerMgrThreadPoolSize = 32;

View File

@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/kernel/sgd_kernel.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// Declares the SGD optimizer kernel for the float specialization of SGDKernel. Six float32 inputs are listed,
// in this order: weight, gradient, learning rate, accumulation, momentum and stat.
// NOTE(review): the registration mechanics are defined by REG_OPTIMIZER_KERNEL outside this file — presumably it
// adds the kernel to the optimizer kernel factory; confirm against optimizer_kernel_factory.h.
REG_OPTIMIZER_KERNEL(SGD,
ParamsInfo()
.AddInputNameType(kWeight, kNumberTypeFloat32)
.AddInputNameType(kGradient, kNumberTypeFloat32)
.AddInputNameType(kLearningRate, kNumberTypeFloat32)
.AddInputNameType(kAccumulation, kNumberTypeFloat32)
.AddInputNameType(kMomentum, kNumberTypeFloat32)
.AddInputNameType(kStat, kNumberTypeFloat32),
SGDKernel, float)
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include "backend/kernel_compiler/cpu/sgd_cpu_kernel.h"
#include "fl/server/kernel/optimizer_kernel.h"
#include "fl/server/kernel/optimizer_kernel_factory.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
using mindspore::kernel::SGDCPUKernel;
// Server-side SGD optimizer kernel. The computation itself is delegated to SGDCPUKernel<T>; this class only adds
// the server bookkeeping provided through OptimizerKernel.
template <typename T>
class SGDKernel : public SGDCPUKernel<T>, public OptimizerKernel {
 public:
  SGDKernel() = default;
  ~SGDKernel() override = default;

  // Initializes the underlying CPU kernel first, then the server-side input/output sizes, and finally records
  // which kernel node inputs are reused.
  void InitKernel(const CNodePtr &cnode) override {
    SGDCPUKernel<T>::InitKernel(cnode);
    InitServerKernelInputOutputSize(cnode);
    GenerateReuseKernelNodeInfo();
  }

  // Forwards the launch unchanged to the CPU SGD implementation and returns its result.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    const bool launched = SGDCPUKernel<T>::Launch(inputs, workspace, outputs);
    return launched;
  }

  // Records the kernel node input index of every reused parameter. The gradient (index 1) is not registered.
  void GenerateReuseKernelNodeInfo() override {
    MS_LOG(INFO) << "SGD reuse 'weight', 'learning rate', 'accumulation', 'momentum' and 'stat' of the kernel node.";
    auto mark_reused = [this](const auto &param_name, int input_index) {
      reuse_kernel_node_inputs_info_.insert(std::make_pair(param_name, input_index));
    };
    mark_reused(kWeight, 0);
    mark_reused(kLearningRate, 2);
    mark_reused(kAccumulation, 3);
    mark_reused(kMomentum, 4);
    mark_reused(kStat, 5);
  }
};
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_

View File

@ -0,0 +1,6 @@
{
"recovery": {
"storge_type": 1,
"storage_file_path": "recovery.json"
}
}

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
# Kill every process of a cross-silo FEMNIST run, identified by the scheduler port on its command line.
parser = argparse.ArgumentParser(description="Finish test_cross_silo_femnist.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)
args, _ = parser.parse_known_args()

# Pattern searched for in the `ps -ef` output.
port_pattern = "scheduler_port=" + str(args.scheduler_port)

# List matching PIDs, excluding the grep processes themselves and anything containing "finish".
list_pids = "ps -ef|grep \"" + port_pattern + "\"  | grep -v \"grep\" | grep -v \"finish\" |awk '{print $2}'"
# Force-kill each PID and report it.
kill_each = "for id in $pid; do kill -9 $id && echo \"killed $id\"; done"

cmd = "pid=`" + list_pids + "` && " + kill_each
subprocess.call(['bash', '-c', cmd])

View File

@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
# Launch the scheduler process of the cross-silo FEMNIST case in the background.
parser = argparse.ArgumentParser(description="Run test_cross_silo_femnist.py case")
for flag_name, flag_type, flag_default in (
        ("--device_target", str, "CPU"),
        ("--server_mode", str, "FEDERATED_LEARNING"),
        ("--worker_num", int, 1),
        ("--server_num", int, 2),
        ("--scheduler_ip", str, "127.0.0.1"),
        ("--scheduler_port", int, 8113),
        ("--scheduler_manage_port", int, 11202),
        ("--config_file_path", str, ""),
        ("--dataset_path", str, ""),
):
    parser.add_argument(flag_name, type=flag_type, default=flag_default)
args, _ = parser.parse_known_args()

# Options forwarded to test_cross_silo_femnist.py, in the order they appear on its command line.
sched_options = (
    ("device_target", args.device_target),
    ("server_mode", args.server_mode),
    ("ms_role", "MS_SCHED"),
    ("worker_num", args.worker_num),
    ("server_num", args.server_num),
    ("config_file_path", args.config_file_path),
    ("scheduler_ip", args.scheduler_ip),
    ("scheduler_port", args.scheduler_port),
    ("scheduler_manage_port", args.scheduler_manage_port),
    ("dataset_path", args.dataset_path),
    ("user_id", 0),
)

# Recreate ./scheduler/, enter it, and start the test script with its output redirected to scheduler.log.
cmd_sched = (
    "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
    "mkdir ${execute_path}/scheduler/ &&"
    "cd ${execute_path}/scheduler/ || exit && export GLOG_v=1 &&"
    "python ${self_path}/../test_cross_silo_femnist.py"
)
cmd_sched += "".join(" --" + key + "=" + str(value) for key, value in sched_options)
cmd_sched += " > scheduler.log 2>&1 &"

subprocess.call(['bash', '-c', cmd_sched])

View File

@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
# Launch the local server processes of the cross-silo FEMNIST case in the background.
import time

parser = argparse.ArgumentParser(description="Run test_cross_silo_femnist.py case")
for flag_name, flag_type, flag_default in (
        ("--device_target", str, "CPU"),
        ("--server_mode", str, "FEDERATED_LEARNING"),
        ("--worker_num", int, 1),
        ("--server_num", int, 2),
        ("--scheduler_ip", str, "127.0.0.1"),
        ("--scheduler_port", int, 8113),
        ("--fl_server_port", int, 6666),
        ("--start_fl_job_threshold", int, 1),
        ("--start_fl_job_time_window", int, 3000),
        ("--update_model_ratio", float, 1.0),
        ("--update_model_time_window", int, 3000),
        ("--fl_name", str, "Lenet"),
        ("--fl_iteration_num", int, 25),
        ("--client_epoch_num", int, 20),
        ("--client_batch_size", int, 32),
        ("--client_learning_rate", float, 0.1),
        ("--local_server_num", int, -1),
        ("--config_file_path", str, ""),
        ("--encrypt_type", str, "NOT_ENCRYPT"),
        # parameters for encrypt_type='DP_ENCRYPT'
        ("--dp_eps", float, 50.0),
        ("--dp_delta", float, 0.01),  # 1/worker_num
        ("--dp_norm_clip", float, 1.0),
        # parameters for encrypt_type='PW_ENCRYPT'
        ("--share_secrets_ratio", float, 1.0),
        ("--cipher_time_window", int, 300000),
        ("--reconstruct_secrets_threshold", int, 3),
        ("--dataset_path", str, ""),
):
    parser.add_argument(flag_name, type=flag_type, default=flag_default)
args, _ = parser.parse_known_args()

server_num = args.server_num
# -1 means "start every server of the cluster on this machine".
local_server_num = server_num if args.local_server_num == -1 else args.local_server_num
assert local_server_num <= server_num, "The local server number should not be bigger than total server number."

for i in range(local_server_num):
    # Each server runs in its own freshly created directory ./server_<i>/.
    server_dir = "${execute_path}/server_" + str(i) + "/"

    # Options forwarded to test_cross_silo_femnist.py; every server listens on its own FL port.
    server_options = (
        ("device_target", args.device_target),
        ("server_mode", args.server_mode),
        ("ms_role", "MS_SERVER"),
        ("worker_num", args.worker_num),
        ("server_num", server_num),
        ("scheduler_ip", args.scheduler_ip),
        ("scheduler_port", args.scheduler_port),
        ("fl_server_port", args.fl_server_port + i),
        ("start_fl_job_threshold", args.start_fl_job_threshold),
        ("start_fl_job_time_window", args.start_fl_job_time_window),
        ("update_model_ratio", args.update_model_ratio),
        ("update_model_time_window", args.update_model_time_window),
        ("fl_name", args.fl_name),
        ("fl_iteration_num", args.fl_iteration_num),
        ("config_file_path", args.config_file_path),
        ("client_epoch_num", args.client_epoch_num),
        ("client_batch_size", args.client_batch_size),
        ("client_learning_rate", args.client_learning_rate),
        ("encrypt_type", args.encrypt_type),
        ("share_secrets_ratio", args.share_secrets_ratio),
        ("cipher_time_window", args.cipher_time_window),
        ("reconstruct_secrets_threshold", args.reconstruct_secrets_threshold),
        ("dp_eps", args.dp_eps),
        ("dp_delta", args.dp_delta),
        ("dp_norm_clip", args.dp_norm_clip),
        ("dataset_path", args.dataset_path),
        ("user_id", 0),
    )

    cmd_server = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
    cmd_server += "rm -rf " + server_dir + " &&"
    cmd_server += "mkdir " + server_dir + " &&"
    cmd_server += "cd " + server_dir + " || exit && export GLOG_v=1 &&"
    cmd_server += "python ${self_path}/../test_cross_silo_femnist.py"
    cmd_server += "".join(" --" + key + "=" + str(value) for key, value in server_options)
    cmd_server += " > server.log 2>&1 &"

    # Short pause before each launch so the servers are not all started at the same instant.
    time.sleep(0.3)
    subprocess.call(['bash', '-c', cmd_server])

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import subprocess
# Launch the local worker processes of the cross-silo FEMNIST case in the background.
parser = argparse.ArgumentParser(description="Run test_cross_silo_femnist.py case")
for flag_name, flag_type, flag_default in (
        ("--device_target", str, "GPU"),
        ("--server_mode", str, "FEDERATED_LEARNING"),
        ("--worker_num", int, 1),
        ("--server_num", int, 2),
        ("--scheduler_ip", str, "127.0.0.1"),
        ("--scheduler_port", int, 8113),
        ("--fl_iteration_num", int, 25),
        ("--client_epoch_num", int, 20),
        ("--worker_step_num_per_iteration", int, 65),
        ("--local_worker_num", int, -1),
        ("--config_file_path", str, ""),
        ("--dataset_path", str, ""),
):
    parser.add_argument(flag_name, type=flag_type, default=flag_default)
args, _ = parser.parse_known_args()

worker_num = args.worker_num
# -1 means "start every worker of the cluster on this machine".
local_worker_num = worker_num if args.local_worker_num == -1 else args.local_worker_num
assert local_worker_num <= worker_num, "The local worker number should not be bigger than total worker number."

for i in range(local_worker_num):
    # Each worker runs in its own freshly created directory ./worker_<i>/.
    worker_dir = "${execute_path}/worker_" + str(i) + "/"

    # Options forwarded to test_cross_silo_femnist.py; the loop index doubles as the worker's user id.
    worker_options = (
        ("device_target", args.device_target),
        ("server_mode", args.server_mode),
        ("ms_role", "MS_WORKER"),
        ("worker_num", worker_num),
        ("server_num", args.server_num),
        ("scheduler_ip", args.scheduler_ip),
        ("scheduler_port", args.scheduler_port),
        ("config_file_path", args.config_file_path),
        ("fl_iteration_num", args.fl_iteration_num),
        ("client_epoch_num", args.client_epoch_num),
        ("worker_step_num_per_iteration", args.worker_step_num_per_iteration),
        ("dataset_path", args.dataset_path),
        ("user_id", i),
    )

    cmd_worker = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && "
    cmd_worker += "rm -rf " + worker_dir + " &&"
    cmd_worker += "mkdir " + worker_dir + " &&"
    cmd_worker += "cd " + worker_dir + " || exit && export GLOG_v=1 &&"
    cmd_worker += "python ${self_path}/../test_cross_silo_femnist.py"
    cmd_worker += "".join(" --" + key + "=" + str(value) for key, value in worker_options)
    cmd_worker += " > worker.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_worker])

View File

@ -0,0 +1,368 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import time
import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore import Tensor
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as PV
import mindspore.dataset.transforms.py_transforms as PT
import mindspore.dataset.transforms.c_transforms as tC
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.train.callback import Callback
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
# Command-line interface of the cross-silo FEMNIST test. Unknown options are tolerated (parse_known_args) so the
# launcher scripts can pass a common set of flags to every role.
parser = argparse.ArgumentParser(description="test_cross_silo_femnist")

# Cluster topology and role of this process.
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)

# Federated-learning job settings forwarded to the FL context.
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")

# Encryption settings forwarded to the FL context.
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)

# Local dataset selection: data is read from <dataset_path>/dataset_<user_id>/{train,test}.
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--user_id", type=str, default="0")

# Data pipeline settings.
# img_size is indexed as img_size[0]/img_size[1] by the dataset builder, so it must be a sequence: accept
# exactly three integers (h w c) on the command line instead of a single int.
parser.add_argument('--img_size', type=int, nargs=3, default=(32, 32, 1), help='the image size of (h,w,c)')
# The batch size is a count, so parse it as an int; a float value cannot be used as a dataset batch size.
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when create the dataLoader')

args, _ = parser.parse_known_args()
# Mirror the parsed arguments into module-level names; train() reads several of them directly.
device_target, server_mode, ms_role = args.device_target, args.server_mode, args.ms_role
worker_num, server_num = args.worker_num, args.server_num
scheduler_ip, scheduler_port = args.scheduler_ip, args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold, start_fl_job_time_window = args.start_fl_job_threshold, args.start_fl_job_time_window
update_model_ratio, update_model_time_window = args.update_model_ratio, args.update_model_time_window
fl_name, fl_iteration_num = args.fl_name, args.fl_iteration_num
client_epoch_num, client_batch_size = args.client_epoch_num, args.client_batch_size
client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio, cipher_time_window = args.share_secrets_ratio, args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps, dp_delta, dp_norm_clip = args.dp_eps, args.dp_delta, args.dp_norm_clip
dataset_path, user_id = args.dataset_path, args.user_id

# Keyword arguments handed to context.set_fl_context(); federated learning is always enabled here.
ctx = dict(
    enable_fl=True,
    server_mode=server_mode,
    ms_role=ms_role,
    worker_num=worker_num,
    server_num=server_num,
    scheduler_ip=scheduler_ip,
    scheduler_port=scheduler_port,
    fl_server_port=fl_server_port,
    start_fl_job_threshold=start_fl_job_threshold,
    start_fl_job_time_window=start_fl_job_time_window,
    update_model_ratio=update_model_ratio,
    update_model_time_window=update_model_time_window,
    fl_name=fl_name,
    fl_iteration_num=fl_iteration_num,
    client_epoch_num=client_epoch_num,
    client_batch_size=client_batch_size,
    client_learning_rate=client_learning_rate,
    worker_step_num_per_iteration=worker_step_num_per_iteration,
    scheduler_manage_port=scheduler_manage_port,
    config_file_path=config_file_path,
    share_secrets_ratio=share_secrets_ratio,
    cipher_time_window=cipher_time_window,
    reconstruct_secrets_threshold=reconstruct_secrets_threshold,
    dp_eps=dp_eps,
    dp_delta=dp_delta,
    dp_norm_clip=dp_norm_clip,
    encrypt_type=encrypt_type,
)

# Graph mode on the requested device, then apply the federated-learning settings.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_fl_context(**ctx)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Create a bias-free Conv2d layer in "valid" pad mode whose weights come from weight_variable()."""
    initial_weight = weight_variable()
    layer_options = {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "weight_init": initial_weight,
        "has_bias": False,
        "pad_mode": "valid",
    }
    return nn.Conv2d(in_channels, out_channels, **layer_options)
def fc_with_initialize(input_channels, out_channels):
    """Create a Dense layer whose weight and bias each get their own weight_variable() initializer."""
    weight_init, bias_init = weight_variable(), weight_variable()
    return nn.Dense(input_channels, out_channels, weight_init, bias_init)
def weight_variable():
    """Return the TruncatedNormal(0.02) initializer shared by the conv and dense layers."""
    initializer = TruncatedNormal(0.02)
    return initializer
class LeNet5(nn.Cell):
    """LeNet-5 style classifier: two conv/ReLU/max-pool stages followed by three dense layers.

    Args:
        num_class: Size of the final dense layer's output.
        channel: Number of channels of the input images.
    """

    def __init__(self, num_class=10, channel=3):
        super().__init__()
        self.num_class = num_class
        # Feature extractor: 5x5 "valid" convolutions (see conv()).
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Classifier head. 16 * 5 * 5 is the flattened feature size for 32x32 inputs
        # (32 -> 28 -> 14 -> 10 -> 5 through the conv/pool stages).
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        # Stateless helpers shared by several stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the output of the last dense layer (no activation applied)."""
        stage1 = self.max_pool2d(self.relu(self.conv1(x)))
        stage2 = self.max_pool2d(self.relu(self.conv2(stage1)))
        flat = self.flatten(stage2)
        hidden1 = self.relu(self.fc1(flat))
        hidden2 = self.relu(self.fc2(hidden1))
        return self.fc3(hidden2)
class LossGet(Callback):
    """Training callback that records the loss of the packaged model and the per-step time of an epoch.

    Args:
        per_print_times: Record the loss every this many steps; 0 disables recording.
        data_size: Divisor used to turn the epoch duration into a per-step duration.

    Raises:
        ValueError: If per_print_times is not a non-negative int.
    """

    def __init__(self, per_print_times, data_size):
        super(LossGet, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self._loss = 0.0
        self.data_size = data_size
        self.loss_list = []
        # Initialise the timing state up front so epoch_end() and get_per_step_time() never hit a missing
        # attribute when they are called before epoch_begin()/epoch_end() have run.
        self.epoch_time = time.time()
        self._per_step_mseconds = 0.0

    def step_end(self, run_context):
        """Reduce the network output to a scalar loss, validate it, and record it when due.

        Raises:
            ValueError: If the loss is NaN or infinite.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        # A tuple/list output carries the loss in its first element.
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        # np.mean() yields a NumPy scalar; np.float32 is not a Python float, so np.floating must be accepted
        # too or an invalid float32 loss would slip through this check.
        if isinstance(loss, (float, np.floating)) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))

        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss = loss
            self.loss_list.append(loss)

    def epoch_begin(self, run_context):
        """Remember when the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Convert the elapsed epoch time into milliseconds per step (elapsed / data_size)."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self._per_step_mseconds = epoch_mseconds / self.data_size

    def get_loss(self):
        """Return the list of every loss recorded so far (the list is never cleared by this class)."""
        return self.loss_list  # todo return self._loss

    def get_per_step_time(self):
        """Return the per-step time in milliseconds of the last finished epoch (0.0 before the first one)."""
        return self._per_step_mseconds
def mkdir(path):
    """Create the directory `path` unless something already exists at that location.

    The directory is created directly and an "already exists" error is ignored, instead of checking for
    existence first: a check-then-create sequence can fail when another process creates the path in between.
    A missing parent directory still raises FileNotFoundError.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        pass
def count_id(path):
    """Map every entry name found directly under `path` to its integer value.

    Each entry name must be parseable by int(); otherwise int() raises ValueError.
    """
    return {entry: int(entry) for entry in os.listdir(path)}
def create_dataset_from_folder(data_path, img_size, batch_size=32, repeat_size=1, num_parallel_workers=1,
                               shuffle=False):
    """Create a batched image dataset for training or evaluation from a class-per-folder directory.

    Args:
        data_path: Directory whose entries are named after their integer class id (see count_id).
        img_size: Sequence whose first two items are the target height and width.
        batch_size: The number of data records in each group; incomplete trailing batches are dropped.
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers (accepted but not used by this function).
        shuffle: Whether to shuffle with a buffer of 10000 records before batching.

    Returns:
        The dataset after label cast, image transforms, optional shuffle, batch and repeat.
    """
    # Folder name -> class index, so labels equal the numeric folder names.
    class_to_index = count_id(data_path)
    dataset = ds.ImageFolderDataset(dataset_dir=data_path, decode=False, class_indexing=class_to_index)

    # Image pipeline: decode, single-channel grayscale, resize, back to three channels, to tensor.
    target_height, target_width = img_size[0], img_size[1]
    image_ops = PT.Compose([
        PV.Decode(),
        PV.Grayscale(1),
        PV.Resize(size=(target_height, target_width)),
        PV.Grayscale(3),
        PV.ToTensor(),
    ])

    # Labels are cast to int32 first, then the images are transformed.
    label_cast = tC.TypeCast(mindspore.int32)
    dataset = dataset.map(input_columns="label", operations=label_cast)
    dataset = dataset.map(input_columns="image", operations=image_ops)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)  # 10000 as in LeNet train script

    return dataset.batch(batch_size, drop_remainder=True).repeat(repeat_size)
def evalute_process(model, eval_data, img_size, batch_size):
    """Evaluate `model` on the folder dataset at `eval_data` and return (accuracy, loss)."""
    metrics = model.eval(create_dataset_from_folder(eval_data, img_size, batch_size), dataset_sink_mode=False)
    accuracy, loss = metrics['Accuracy'], metrics['Loss']
    return accuracy, loss
class StartFLJob(nn.Cell):
    """Cell wrapping the P.StartFLJob operator, which is constructed with the given data size."""

    def __init__(self, data_size):
        super().__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        """Invoke the wrapped operator and return its result."""
        result = self.start_fl_job()
        return result
class UpdateAndGetModel(nn.Cell):
    """Cell that applies P.UpdateModel and then P.GetModel to the same set of weights."""

    def __init__(self, weights):
        super().__init__()
        self.update_model = P.UpdateModel()
        self.get_model = P.GetModel()
        self.weights = weights

    def construct(self):
        """Run UpdateModel on the weights first, then return the result of GetModel on them."""
        self.update_model(self.weights)
        return self.get_model(self.weights)
def train():
    """Run the federated training loop of this process.

    For each of `fl_iteration_num` iterations: workers start an FL job, the model is trained for
    `client_epoch_num` epochs on <dataset_path>/dataset_<user_id>/train, workers then push their weights and
    fetch the model, a checkpoint is written under ./ckpt, and train/test accuracy plus the mean recorded loss
    are printed.
    """
    epoch = client_epoch_num
    network = LeNet5(62, 3)

    # define the loss function
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define the optimizer (learning rate 0.01 and momentum 0.9 are fixed here)
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})

    ds.config.set_seed(1)
    data_root_path = dataset_path
    user = "dataset_" + user_id
    train_path = os.path.join(data_root_path, user, "train")
    test_path = os.path.join(data_root_path, user, "test")

    dataset = create_dataset_from_folder(train_path, args.img_size, args.batch_size, args.repeat_size)
    print("size is ", dataset.get_dataset_size(), flush=True)
    num_batches = dataset.get_dataset_size()

    # Record the loss of every step.
    loss_cb = LossGet(1, num_batches)
    cbs = [loss_cb]

    ckpt_path = "ckpt"
    # exist_ok: a previous run may have left the checkpoint directory behind; that must not abort training.
    os.makedirs(ckpt_path, exist_ok=True)

    for iter_num in range(fl_iteration_num):
        if context.get_fl_context("ms_role") == "MS_WORKER":
            start_fl_job = StartFLJob(dataset.get_dataset_size() * args.batch_size)
            start_fl_job()

        for _ in range(epoch):
            print("step is ", epoch, flush=True)
            model.train(1, dataset, callbacks=cbs, dataset_sink_mode=False)

        if context.get_fl_context("ms_role") == "MS_WORKER":
            update_and_get_model = UpdateAndGetModel(net_opt.parameters)
            update_and_get_model()

        ckpt_name = user_id + "-fl-ms-bs32-" + str(iter_num) + "epoch.ckpt"
        ckpt_name = os.path.join(ckpt_path, ckpt_name)
        save_checkpoint(network, ckpt_name)

        train_acc, _ = evalute_process(model, train_path, args.img_size, args.batch_size)
        test_acc, _ = evalute_process(model, test_path, args.img_size, args.batch_size)

        # The callback keeps every loss recorded so far, so this is the mean over all steps up to now.
        # With nothing recorded (e.g. zero epochs) report NaN instead of dividing by zero.
        loss_list = loss_cb.get_loss()
        loss = sum(loss_list) / len(loss_list) if loss_list else float("nan")
        print('local epoch: {}, loss: {}, trian acc: {}, test acc: {}'.format(iter_num, loss, train_acc, test_acc),
              flush=True)


if __name__ == "__main__":
    train()

View File

@ -176,6 +176,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gp
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/fl/server/kernel/apply_momentum_kernel.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")