diff --git a/mindspore/ccsrc/fl/CMakeLists.txt b/mindspore/ccsrc/fl/CMakeLists.txt index dbd8621642f..bab81a91bb4 100644 --- a/mindspore/ccsrc/fl/CMakeLists.txt +++ b/mindspore/ccsrc/fl/CMakeLists.txt @@ -5,6 +5,7 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc") + list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/sgd_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc") diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index 775e624d41b..87dae8a1207 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -111,6 +111,7 @@ constexpr auto kAdamEps = "eps"; constexpr auto kFtrlLinear = "linear"; constexpr auto kDataSize = "data_size"; constexpr auto kNewDataSize = "new_data_size"; +constexpr auto kStat = "stat"; // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // launched. 
@@ -155,11 +156,14 @@ const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs", {"weight_decay", 7}, {"grad", 8}}}, {"outputs", {}}}; -const std::map kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx}, - {kFusedSparseAdamName, kSparseAdamNameToIdx}, - {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, - {kApplyAdamOpName, kAdamNameToIdx}, - {"AdamWeightDecay", kAdamWeightDecayNameToIdx}}; +const OptimParamNameToIndex kSGDNameToIdx = { + {"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}}, + {"outputs", {}}}; + +const std::map kNameToIdxMap = { + {kApplyMomentumOpName, kMomentumNameToIdx}, {kFusedSparseAdamName, kSparseAdamNameToIdx}, + {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx}, + {"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}}; constexpr uint32_t kLeaderServerRank = 0; constexpr size_t kWorkerMgrThreadPoolSize = 32; diff --git a/mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc b/mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc new file mode 100644 index 00000000000..0b1d13673c0 --- /dev/null +++ b/mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#include "fl/server/kernel/sgd_kernel.h"

namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// Registers SGDKernel (instantiated for float) as the server-side optimizer kernel named SGD.
// Six float32 inputs are declared, in this order: weight, gradient, learning rate, accumulation,
// momentum and stat. The order must stay in step with the input indices the kernel itself uses.
REG_OPTIMIZER_KERNEL(SGD,
                     ParamsInfo()
                       .AddInputNameType(kWeight, kNumberTypeFloat32)
                       .AddInputNameType(kGradient, kNumberTypeFloat32)
                       .AddInputNameType(kLearningRate, kNumberTypeFloat32)
                       .AddInputNameType(kAccumulation, kNumberTypeFloat32)
                       .AddInputNameType(kMomentum, kNumberTypeFloat32)
                       .AddInputNameType(kStat, kNumberTypeFloat32),
                     SGDKernel, float)
}  // namespace kernel
}  // namespace server
}  // namespace fl
}  // namespace mindspore
+ */ + +#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_ +#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/cpu/sgd_cpu_kernel.h" +#include "fl/server/kernel/optimizer_kernel.h" +#include "fl/server/kernel/optimizer_kernel_factory.h" + +namespace mindspore { +namespace fl { +namespace server { +namespace kernel { +using mindspore::kernel::SGDCPUKernel; +template +class SGDKernel : public SGDCPUKernel, public OptimizerKernel { + public: + SGDKernel() = default; + ~SGDKernel() override = default; + + void InitKernel(const CNodePtr &cnode) override { + SGDCPUKernel::InitKernel(cnode); + InitServerKernelInputOutputSize(cnode); + GenerateReuseKernelNodeInfo(); + } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return SGDCPUKernel::Launch(inputs, workspace, outputs); + } + + void GenerateReuseKernelNodeInfo() override { + MS_LOG(INFO) << "SGD reuse 'weight', 'learning rate', 'accumulation', 'momentum' and 'stat' of the kernel node."; + reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 3)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4)); + reuse_kernel_node_inputs_info_.insert(std::make_pair(kStat, 5)); + return; + } +}; +} // namespace kernel +} // namespace server +} // namespace fl +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_ diff --git a/tests/st/fl/cross_silo_femnist/config.json b/tests/st/fl/cross_silo_femnist/config.json new file mode 100644 index 00000000000..37ac6edfb25 --- /dev/null +++ b/tests/st/fl/cross_silo_femnist/config.json @@ -0,0 +1,6 @@ +{ + "recovery": { + "storge_type": 1, + "storage_file_path": "recovery.json" + } +} \ No newline at end of file diff --git 
"""Kill every process of a cross-silo FEMNIST test run, identified by its scheduler port."""
import argparse
import subprocess

parser = argparse.ArgumentParser(description="Finish test_cross_silo_femnist.py case")
parser.add_argument("--scheduler_port", type=int, default=8113)

args, _ = parser.parse_known_args()
scheduler_port = args.scheduler_port

# The shell pipeline lists the processes whose command line carries this scheduler port, drops the
# grep helper and this finish script itself, and force-kills each remaining pid.
cmd = "".join([
    'pid=`ps -ef|grep "scheduler_port=' + str(scheduler_port) + '" ',
    ' | grep -v "grep" | grep -v "finish" |awk \'{print $2}\'` && ',
    'for id in $pid; do kill -9 $id && echo "killed $id"; done',
])

subprocess.call(['bash', '-c', cmd])
"""Launch the scheduler process of the cross-silo FEMNIST federated-learning test.

The scheduler runs in the background inside ``<cwd>/scheduler/`` (recreated on every launch) and
writes its output to ``scheduler.log`` in that directory.
"""
import argparse
import os
import shlex
import subprocess

parser = argparse.ArgumentParser(description="Run test_cross_silo_femnist.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")

args, _ = parser.parse_known_args()

# The shell snippet used to derive the script location from ${script_self}, which is never set, so
# the test script was only found when this launcher was started from its own directory. Resolve it
# from this file instead so the launcher works from any working directory.
test_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_cross_silo_femnist.py")
work_dir = shlex.quote(os.path.join(os.getcwd(), "scheduler"))

options = [
    ("device_target", args.device_target),
    ("server_mode", args.server_mode),
    ("ms_role", "MS_SCHED"),
    ("worker_num", args.worker_num),
    ("server_num", args.server_num),
    ("config_file_path", args.config_file_path),
    ("scheduler_ip", args.scheduler_ip),
    ("scheduler_port", args.scheduler_port),
    ("scheduler_manage_port", args.scheduler_manage_port),
    ("dataset_path", args.dataset_path),
    ("user_id", 0),
]

cmd_sched = "rm -rf {0} && mkdir {0} && cd {0} || exit && export GLOG_v=1 && ".format(work_dir)
cmd_sched += "python " + shlex.quote(test_script)
# Quote every value so that paths containing spaces or shell metacharacters survive `bash -c`.
cmd_sched += "".join(" --{}={}".format(name, shlex.quote(str(value))) for name, value in options)
cmd_sched += " > scheduler.log 2>&1 &"

subprocess.call(['bash', '-c', cmd_sched])
"""Launch the local server processes of the cross-silo FEMNIST federated-learning test.

Server ``i`` runs in the background inside ``<cwd>/server_<i>/`` (recreated on every launch), listens
on ``fl_server_port + i`` and writes its output to ``server.log`` in that directory.
"""
import argparse
import os
import shlex
import subprocess
import time

parser = argparse.ArgumentParser(description="Run test_cross_silo_femnist.py case")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)  # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--dataset_path", type=str, default="")

args, _ = parser.parse_known_args()

# -1 means "start every server of the cluster on this machine".
local_server_num = args.server_num if args.local_server_num == -1 else args.local_server_num
if local_server_num > args.server_num:
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    raise ValueError("The local server number should not be bigger than total server number.")

# The shell snippet used to derive the script location from ${script_self}, which is never set, so
# the test script was only found when this launcher was started from its own directory.
test_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_cross_silo_femnist.py")

# Options forwarded unchanged to every server process.
SHARED_OPTIONS = (
    "device_target", "server_mode", "worker_num", "server_num", "scheduler_ip", "scheduler_port",
    "start_fl_job_threshold", "start_fl_job_time_window", "update_model_ratio",
    "update_model_time_window", "fl_name", "fl_iteration_num", "config_file_path",
    "client_epoch_num", "client_batch_size", "client_learning_rate", "encrypt_type",
    "share_secrets_ratio", "cipher_time_window", "reconstruct_secrets_threshold",
    "dp_eps", "dp_delta", "dp_norm_clip", "dataset_path",
)

for i in range(local_server_num):
    options = [(name, getattr(args, name)) for name in SHARED_OPTIONS]
    options.append(("ms_role", "MS_SERVER"))
    # Each local server gets its own port, counted up from the base port.
    options.append(("fl_server_port", args.fl_server_port + i))
    options.append(("user_id", 0))

    work_dir = shlex.quote(os.path.join(os.getcwd(), "server_" + str(i)))
    cmd_server = "rm -rf {0} && mkdir {0} && cd {0} || exit && export GLOG_v=1 && ".format(work_dir)
    cmd_server += "python " + shlex.quote(test_script)
    # Quote every value so that paths containing spaces or shell metacharacters survive `bash -c`.
    cmd_server += "".join(" --{}={}".format(name, shlex.quote(str(value))) for name, value in options)
    cmd_server += " > server.log 2>&1 &"

    # Short pause before each launch, as in the original script.
    time.sleep(0.3)
    subprocess.call(['bash', '-c', cmd_server])
"""Launch the local worker processes of the cross-silo FEMNIST federated-learning test.

Worker ``i`` runs in the background inside ``<cwd>/worker_<i>/`` (recreated on every launch), is
started with ``--user_id=i`` and writes its output to ``worker.log`` in that directory.
"""
import argparse
import os
import shlex
import subprocess

parser = argparse.ArgumentParser(description="Run test_cross_silo_femnist.py case")
parser.add_argument("--device_target", type=str, default="GPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")

args, _ = parser.parse_known_args()

# -1 means "start every worker of the cluster on this machine".
local_worker_num = args.worker_num if args.local_worker_num == -1 else args.local_worker_num
if local_worker_num > args.worker_num:
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    raise ValueError("The local worker number should not be bigger than total worker number.")

# The shell snippet used to derive the script location from ${script_self}, which is never set, so
# the test script was only found when this launcher was started from its own directory.
test_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_cross_silo_femnist.py")

# Options forwarded unchanged to every worker process.
SHARED_OPTIONS = (
    "device_target", "server_mode", "worker_num", "server_num", "scheduler_ip", "scheduler_port",
    "config_file_path", "fl_iteration_num", "client_epoch_num", "worker_step_num_per_iteration",
    "dataset_path",
)

for i in range(local_worker_num):
    options = [(name, getattr(args, name)) for name in SHARED_OPTIONS]
    options.append(("ms_role", "MS_WORKER"))
    # The worker index doubles as the user id passed to the test script.
    options.append(("user_id", i))

    work_dir = shlex.quote(os.path.join(os.getcwd(), "worker_" + str(i)))
    cmd_worker = "rm -rf {0} && mkdir {0} && cd {0} || exit && export GLOG_v=1 && ".format(work_dir)
    cmd_worker += "python " + shlex.quote(test_script)
    # Quote every value so that paths containing spaces or shell metacharacters survive `bash -c`.
    cmd_worker += "".join(" --{}={}".format(name, shlex.quote(str(value))) for name, value in options)
    cmd_worker += " > worker.log 2>&1 &"

    subprocess.call(['bash', '-c', cmd_worker])
import argparse
import os
import time
import numpy as np

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore import Tensor
import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as PV
import mindspore.dataset.transforms.py_transforms as PT
import mindspore.dataset.transforms.c_transforms as tC
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.train.callback import Callback
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model

parser = argparse.ArgumentParser(description="test_cross_silo_femnist")
parser.add_argument("--device_target", type=str, default="CPU")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--ms_role", type=str, default="MS_WORKER")
parser.add_argument("--worker_num", type=int, default=1)
parser.add_argument("--server_num", type=int, default=1)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01)
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3)
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--user_id", type=str, default="0")

# img_size is indexed as img_size[0] / img_size[1] later on, so it has to be a sequence: take three
# integers on the command line (previously a single int was parsed, which broke the indexing).
parser.add_argument('--img_size', type=int, nargs=3, default=(32, 32, 1), help='the image size of (h,w,c)')
# The batch size is handed to Dataset.batch(), so it must be an integer (was parsed as float).
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--repeat_size', type=int, default=1, help='the repeat size when create the dataLoader')

args, _ = parser.parse_known_args()
device_target = args.device_target
server_mode = args.server_mode
ms_role = args.ms_role
worker_num = args.worker_num
server_num = args.server_num
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshold = args.reconstruct_secrets_threshold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
dataset_path = args.dataset_path
user_id = args.user_id

# Command-line options that are passed straight through to the federated-learning context.
_FL_CONTEXT_KEYS = (
    "server_mode", "ms_role", "worker_num", "server_num", "scheduler_ip", "scheduler_port",
    "fl_server_port", "start_fl_job_threshold", "start_fl_job_time_window", "update_model_ratio",
    "update_model_time_window", "fl_name", "fl_iteration_num", "client_epoch_num",
    "client_batch_size", "client_learning_rate", "worker_step_num_per_iteration",
    "scheduler_manage_port", "config_file_path", "share_secrets_ratio", "cipher_time_window",
    "reconstruct_secrets_threshold", "dp_eps", "dp_delta", "dp_norm_clip", "encrypt_type",
)
ctx = {"enable_fl": True}
ctx.update({key: getattr(args, key) for key in _FL_CONTEXT_KEYS})

context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
context.set_fl_context(**ctx)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free, 'valid'-padded Conv2d layer whose weight uses the shared initializer."""
    weight = weight_variable()
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight,
        has_bias=False,
        pad_mode="valid",
    )
def fc_with_initialize(input_channels, out_channels):
    """Build a Dense layer whose weight and bias both use the shared initializer."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """Return the initializer shared by all layers: TruncatedNormal(0.02)."""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """LeNet-5 style network: two conv/ReLU/max-pool stages followed by three dense layers.

    Args:
        num_class (int): size of the final dense layer. Default: 10.
        channel (int): number of input channels of the first convolution. Default: 3.
    """

    def __init__(self, num_class=10, channel=3):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class LossGet(Callback):
    """Callback that records the training loss and the average time per step.

    Args:
        per_print_times (int): record the loss every `per_print_times` steps; 0 disables recording.
        data_size (int): divisor used to turn the duration of an epoch into a per-step duration.

    Raises:
        ValueError: if `per_print_times` is not a non-negative int.
    """

    def __init__(self, per_print_times, data_size):
        super(LossGet, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self._loss = 0.0
        self.data_size = data_size
        self.loss_list = []
        # Defined up front so the getters and epoch_end never hit a missing attribute.
        self.epoch_time = time.time()
        self._per_step_mseconds = 0.0

    def step_end(self, run_context):
        """Validate the loss of the finished step and record it."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            # Convert to a built-in float: np.mean of a float32 array yields np.float32, which is
            # not a `float` subclass, so the NaN/Inf guard below would silently be skipped.
            loss = float(np.mean(loss.asnumpy()))

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss = loss
            self.loss_list.append(loss)

    def epoch_begin(self, run_context):
        """Remember when the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Compute the average step duration of the finished epoch, in milliseconds."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self._per_step_mseconds = epoch_mseconds / self.data_size

    def get_loss(self):
        """Return every loss recorded so far (the list is never reset by this class)."""
        return self.loss_list  # todo return self._loss

    def get_per_step_time(self):
        """Return the average step duration of the last finished epoch, in milliseconds."""
        return self._per_step_mseconds


def mkdir(path):
    """Create `path` if it does not exist yet."""
    # exist_ok avoids the race between an explicit existence check and the creation.
    os.makedirs(path, exist_ok=True)


def count_id(path):
    """Map every entry name in `path` to its integer value.

    Every entry name must be parseable by int(); any other name raises ValueError.
    """
    files = os.listdir(path)
    ids = {}
    for i in files:
        ids[i] = int(i)
    return ids


def create_dataset_from_folder(data_path, img_size, batch_size=32, repeat_size=1, num_parallel_workers=1,
                               shuffle=False):
    """ create dataset for train or test
    Args:
        data_path: Data path; its entries are used as class folders (see count_id)
        img_size: Sequence whose first two items are the resize height and width
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers (currently not used)
        shuffle: Whether to shuffle the records before batching
    """
    # define dataset
    ids = count_id(data_path)
    mnist_ds = ds.ImageFolderDataset(dataset_dir=data_path, decode=False, class_indexing=ids)
    # define operation parameters
    resize_height, resize_width = img_size[0], img_size[1]

    transform = [
        PV.Decode(),
        PV.Grayscale(1),
        PV.Resize(size=(resize_height, resize_width)),
        PV.Grayscale(3),
        PV.ToTensor()
    ]
    compose = PT.Compose(transform)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=tC.TypeCast(mindspore.int32))
    mnist_ds = mnist_ds.map(input_columns="image", operations=compose)

    # apply DatasetOps
    buffer_size = 10000
    if shuffle:
        mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
def evalute_process(model, eval_data, img_size, batch_size):
    """Evaluate `model` on the image folder `eval_data` and return (accuracy, loss)."""
    ds_eval = create_dataset_from_folder(eval_data, img_size, batch_size)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    return acc['Accuracy'], acc['Loss']


class StartFLJob(nn.Cell):
    """Cell wrapping the StartFLJob operator, created with the local data size."""

    def __init__(self, data_size):
        super(StartFLJob, self).__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        return self.start_fl_job()


class UpdateAndGetModel(nn.Cell):
    """Cell that runs UpdateModel on the given weights and then GetModel on the same weights."""

    def __init__(self, weights):
        super(UpdateAndGetModel, self).__init__()
        self.update_model = P.UpdateModel()
        self.get_model = P.GetModel()
        self.weights = weights

    def construct(self):
        self.update_model(self.weights)
        get_model = self.get_model(self.weights)
        return get_model


def train():
    """Run the federated training loop of one process.

    For each of `fl_iteration_num` iterations: a worker starts an FL job, every process trains
    `client_epoch_num` local epochs, a worker then uploads/fetches the model, and finally a
    checkpoint is saved and the train/test accuracy is printed.
    """
    epoch = client_epoch_num
    network = LeNet5(62, 3)

    # define the loss function
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy(), 'Loss': nn.Loss()})

    ds.config.set_seed(1)
    data_root_path = dataset_path
    user = "dataset_" + user_id
    train_path = os.path.join(data_root_path, user, "train")
    test_path = os.path.join(data_root_path, user, "test")

    dataset = create_dataset_from_folder(train_path, args.img_size, args.batch_size, args.repeat_size)
    print("size is ", dataset.get_dataset_size(), flush=True)
    num_batches = dataset.get_dataset_size()

    loss_cb = LossGet(1, num_batches)
    cbs = [loss_cb]
    ckpt_path = "ckpt"
    # A previous run may have left the directory behind; do not fail on it.
    os.makedirs(ckpt_path, exist_ok=True)

    # The role is fixed for the lifetime of the process, so it is read once.
    is_worker = context.get_fl_context("ms_role") == "MS_WORKER"

    for iter_num in range(fl_iteration_num):
        if is_worker:
            # The data size passed to the FL job must be an integer sample count.
            start_fl_job = StartFLJob(int(dataset.get_dataset_size() * args.batch_size))
            start_fl_job()

        for epoch_idx in range(epoch):
            # Report the epoch that is about to run (previously the constant total was printed).
            print("step is ", epoch_idx, flush=True)
            model.train(1, dataset, callbacks=cbs, dataset_sink_mode=False)

        if is_worker:
            update_and_get_model = UpdateAndGetModel(net_opt.parameters)
            update_and_get_model()

        ckpt_name = user_id + "-fl-ms-bs32-" + str(iter_num) + "epoch.ckpt"
        ckpt_name = os.path.join(ckpt_path, ckpt_name)
        save_checkpoint(network, ckpt_name)

        train_acc, _ = evalute_process(model, train_path, args.img_size, args.batch_size)
        test_acc, _ = evalute_process(model, test_path, args.img_size, args.batch_size)
        # loss_list keeps growing across iterations, so this is the mean over everything recorded
        # so far. Guard against an empty list (no step was executed).
        loss_list = loss_cb.get_loss()
        loss = sum(loss_list) / len(loss_list) if loss_list else float("nan")
        print('local epoch: {}, loss: {}, trian acc: {}, test acc: {}'.format(iter_num, loss, train_acc, test_acc),
              flush=True)


if __name__ == "__main__":
    train()