diff --git a/tests/st/fl/albert/cloud_train.py b/tests/st/fl/albert/cloud_train.py index ba13ba28d4c..78f0fb80483 100644 --- a/tests/st/fl/albert/cloud_train.py +++ b/tests/st/fl/albert/cloud_train.py @@ -16,6 +16,7 @@ import argparse import os import sys +import ast from time import time import numpy as np from mindspore import context, Tensor @@ -66,6 +67,9 @@ def parse_args(): parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) + parser.add_argument("--client_password", type=str, default="") + parser.add_argument("--server_password", type=str, default="") + parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) return parser.parse_args() @@ -100,6 +104,9 @@ def server_train(args): share_secrets_ratio = args.share_secrets_ratio cipher_time_window = args.cipher_time_window reconstruct_secrets_threshold = args.reconstruct_secrets_threshold + client_password = args.client_password + server_password = args.server_password + enable_ssl = args.enable_ssl # Replace some parameters with federated learning parameters. train_cfg.max_global_epoch = fl_iteration_num @@ -129,6 +136,9 @@ def server_train(args): "share_secrets_ratio": share_secrets_ratio, "cipher_time_window": cipher_time_window, "reconstruct_secrets_threshold": reconstruct_secrets_threshold, + "client_password": client_password, + "server_password": server_password, + "enable_ssl": enable_ssl } if not os.path.exists(output_dir): diff --git a/tests/st/fl/albert/run_hybrid_train_sched.py b/tests/st/fl/albert/run_hybrid_train_sched.py index d879e1f04a1..6fa9e2a0943 100644 --- a/tests/st/fl/albert/run_hybrid_train_sched.py +++ b/tests/st/fl/albert/run_hybrid_train_sched.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run train_cloud.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -25,6 +26,9 @@ parser.add_argument("--server_num", type=int, default=2) parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--scheduler_manage_port", type=int, default=11202) +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -34,6 +38,9 @@ server_num = args.server_num scheduler_ip = args.scheduler_ip scheduler_port = args.scheduler_port scheduler_manage_port = args.scheduler_manage_port +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl os.environ['MS_NODE_ID'] = "20" cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" @@ -47,6 +54,9 @@ cmd_sched += " --worker_num=" + str(worker_num) cmd_sched += " --server_num=" + str(server_num) cmd_sched += " --scheduler_ip=" + scheduler_ip cmd_sched += " --scheduler_port=" + str(scheduler_port) +cmd_sched += " --client_password=" + str(client_password) +cmd_sched += " --server_password=" + str(server_password) +cmd_sched += " --enable_ssl=" + str(enable_ssl) cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port) cmd_sched += " > scheduler.log 2>&1 &" diff --git a/tests/st/fl/albert/run_hybrid_train_server.py b/tests/st/fl/albert/run_hybrid_train_server.py index 3589a58707f..c300083cf77 100644 --- a/tests/st/fl/albert/run_hybrid_train_server.py +++ b/tests/st/fl/albert/run_hybrid_train_server.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run train_cloud.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -45,6 +46,9 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -71,6 +75,9 @@ encrypt_type = args.encrypt_type share_secrets_ratio = args.share_secrets_ratio cipher_time_window = args.cipher_time_window reconstruct_secrets_threshold = args.reconstruct_secrets_threshold +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl if local_server_num == -1: local_server_num = server_num @@ -107,6 +114,9 @@ for i in range(local_server_num): cmd_server += " --encrypt_type=" + str(encrypt_type) cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) cmd_server += " --cipher_time_window=" + str(cipher_time_window) + cmd_server += " --client_password=" + str(client_password) + cmd_server += " --server_password=" + str(server_password) + cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) cmd_server += " > server.log 2>&1 &" diff --git a/tests/st/fl/hybrid_lenet/run_hybrid_train_sched.py b/tests/st/fl/hybrid_lenet/run_hybrid_train_sched.py index 8d09dd63791..51ae8a9dc7b 100644 --- a/tests/st/fl/hybrid_lenet/run_hybrid_train_sched.py +++ b/tests/st/fl/hybrid_lenet/run_hybrid_train_sched.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run test_hybrid_train_lenet.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -26,6 +27,9 @@ parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1") parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -36,6 +40,9 @@ scheduler_ip = args.scheduler_ip scheduler_port = args.scheduler_port scheduler_manage_port = args.scheduler_manage_port config_file_path = args.config_file_path +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl os.environ['MS_NODE_ID'] = "20" cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" @@ -50,6 +57,9 @@ cmd_sched += " --server_num=" + str(server_num) cmd_sched += " --config_file_path=" + str(config_file_path) cmd_sched += " --scheduler_ip=" + scheduler_ip cmd_sched += " --scheduler_port=" + str(scheduler_port) +cmd_sched += " --client_password=" + str(client_password) +cmd_sched += " --server_password=" + str(server_password) +cmd_sched += " --enable_ssl=" + str(enable_ssl) cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port) cmd_sched += " > scheduler.log 2>&1 &" diff --git a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py index 285acba03aa..7b9fd098269 100644 --- a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py +++ b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run test_hybrid_train_lenet.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -45,6 +46,9 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -72,6 +76,9 @@ reconstruct_secrets_threshold = args.reconstruct_secrets_threshold dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl if local_server_num == -1: local_server_num = server_num @@ -109,6 +116,9 @@ for i in range(local_server_num): cmd_server += " --reconstruct_secrets_threshold=" + str(reconstruct_secrets_threshold) cmd_server += " --dp_eps=" + str(dp_eps) cmd_server += " --dp_delta=" + str(dp_delta) + cmd_server += " --client_password=" + str(client_password) + cmd_server += " --server_password=" + str(server_password) + cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) cmd_server += " > server.log 2>&1 &" diff --git a/tests/st/fl/hybrid_lenet/run_hybrid_train_worker.py b/tests/st/fl/hybrid_lenet/run_hybrid_train_worker.py index 017f055b3b9..743ecfa68ed 100644 --- a/tests/st/fl/hybrid_lenet/run_hybrid_train_worker.py +++ b/tests/st/fl/hybrid_lenet/run_hybrid_train_worker.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run test_hybrid_train_lenet.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -28,6 +29,9 @@ parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--local_worker_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -40,6 +44,9 @@ fl_iteration_num = args.fl_iteration_num worker_step_num_per_iteration = args.worker_step_num_per_iteration local_worker_num = args.local_worker_num config_file_path = args.config_file_path +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl if local_worker_num == -1: local_worker_num = worker_num @@ -61,6 +68,9 @@ for i in range(local_worker_num): cmd_worker += " --scheduler_ip=" + scheduler_ip cmd_worker += " --scheduler_port=" + str(scheduler_port) cmd_worker += " --config_file_path=" + str(config_file_path) + cmd_worker += " --client_password=" + str(client_password) + cmd_worker += " --server_password=" + str(server_password) + cmd_worker += " --enable_ssl=" + str(enable_ssl) cmd_worker += " --fl_iteration_num=" + str(fl_iteration_num) cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration) cmd_worker += " > worker.log 2>&1 &" diff --git a/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py b/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py index 5aecb8f49b8..60d65aab301 100644 --- a/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py +++ b/tests/st/fl/hybrid_lenet/run_server_disaster_recovery.py @@ -49,6 +49,9 @@ parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() @@ -80,6 +83,9 @@ dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip encrypt_type = args.encrypt_type +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl #Step 1: make the server offline. @@ -126,6 +132,9 @@ cmd_server += " --pki_verify=" + str(pki_verify) cmd_server += " --root_first_crl_path=" + str(root_first_crl_path) cmd_server += " --root_second_crl_path=" + str(root_second_crl_path) cmd_server += " --root_second_crl_path=" + str(root_second_crl_path) +cmd_server += " --client_password=" + str(client_password) +cmd_server += " --server_password=" + str(server_password) +cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --sts_jar_path=" + str(sts_jar_path) cmd_server += " --sts_properties_path=" + str(sts_properties_path) cmd_server += " > server.log 2>&1 &" diff --git a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py index c2f7e5c6c3e..ab5c8bf6fa6 100644 --- a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py +++ b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py @@ -14,6 +14,7 @@ # ============================================================================ import argparse +import ast import numpy as np import mindspore.context as context @@ -54,6 +55,9 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -83,6 +87,9 @@ reconstruct_secrets_threshold = args.reconstruct_secrets_threshold dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl ctx = { "enable_fl": True, @@ -111,7 +118,10 @@ ctx = { "dp_eps": dp_eps, "dp_delta": dp_delta, "dp_norm_clip": dp_norm_clip, - "encrypt_type": encrypt_type + "encrypt_type": encrypt_type, + "client_password": client_password, + "server_password": server_password, + "enable_ssl": enable_ssl } context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) diff --git a/tests/st/fl/mobile/run_mobile_sched.py b/tests/st/fl/mobile/run_mobile_sched.py index 701d469af70..5d8f257e5ef 100644 --- a/tests/st/fl/mobile/run_mobile_sched.py +++ b/tests/st/fl/mobile/run_mobile_sched.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -27,6 +28,9 @@ parser.add_argument("--scheduler_port", type=int, default=8113) parser.add_argument("--fl_server_port", type=int, default=6666) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) if __name__ == "__main__": args, _ = parser.parse_known_args() @@ -39,6 +43,9 @@ if __name__ == "__main__": fl_server_port = args.fl_server_port scheduler_manage_port = args.scheduler_manage_port config_file_path = args.config_file_path + client_password = args.client_password + server_password = args.server_password + enable_ssl = args.enable_ssl os.environ['MS_NODE_ID'] = "20" cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&" @@ -54,6 +61,9 @@ if __name__ == "__main__": cmd_sched += " --scheduler_port=" + str(scheduler_port) cmd_sched += " --config_file_path=" + str(config_file_path) cmd_sched += " --fl_server_port=" + str(fl_server_port) + cmd_sched += " --client_password=" + str(client_password) + cmd_sched += " --server_password=" + str(server_password) + cmd_sched += " --enable_ssl=" + str(enable_ssl) cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port) cmd_sched += " > scheduler.log 2>&1 &" diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py index 2cffdcb1df5..e52e3444a16 100644 --- a/tests/st/fl/mobile/run_mobile_server.py +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -16,6 +16,7 @@ import argparse import subprocess import os +import ast parser = argparse.ArgumentParser(description="Run test_mobile_lenet.py case") parser.add_argument("--device_target", type=str, default="CPU") @@ -45,6 +46,9 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) if __name__ == "__main__": args, _ = parser.parse_known_args() @@ -73,6 +77,9 @@ if __name__ == "__main__": dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip encrypt_type = args.encrypt_type + client_password = args.client_password + server_password = args.server_password + enable_ssl = args.enable_ssl if local_server_num == -1: local_server_num = server_num @@ -110,6 +117,9 @@ if __name__ == "__main__": cmd_server += " --dp_eps=" + str(dp_eps) cmd_server += " --dp_delta=" + str(dp_delta) cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) + cmd_server += " --client_password=" + str(client_password) + cmd_server += " --server_password=" + str(server_password) + cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --encrypt_type=" + str(encrypt_type) cmd_server += " > server.log 2>&1 &" diff --git a/tests/st/fl/mobile/run_server_disaster_recovery.py b/tests/st/fl/mobile/run_server_disaster_recovery.py index 3f36004323f..2bf0f80ad20 100644 --- a/tests/st/fl/mobile/run_server_disaster_recovery.py +++ b/tests/st/fl/mobile/run_server_disaster_recovery.py @@ -48,6 +48,8 @@ parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # usually equals 1/start_fl_job_threshold parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) @@ -79,6 +81,8 @@ dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip encrypt_type = args.encrypt_type +client_password = args.client_password +server_password = args.server_password enable_ssl = args.enable_ssl #Step 1: make the server offline. @@ -126,6 +130,9 @@ cmd_server += " --pki_verify=" + str(pki_verify) cmd_server += " --root_first_crl_path=" + str(root_first_crl_path) cmd_server += " --root_second_crl_path=" + str(root_second_crl_path) cmd_server += " --root_second_crl_path=" + str(root_second_crl_path) +cmd_server += " --client_password=" + str(client_password) +cmd_server += " --server_password=" + str(server_password) +cmd_server += " --enable_ssl=" + str(enable_ssl) cmd_server += " --sts_jar_path=" + str(sts_jar_path) cmd_server += " --sts_properties_path=" + str(sts_properties_path) cmd_server += " > server.log 2>&1 &" diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py index 7916853fa59..42c9a6dc010 100644 --- a/tests/st/fl/mobile/test_mobile_lenet.py +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -14,6 +14,7 @@ # ============================================================================ import argparse +import ast import numpy as np import mindspore.context as context @@ -52,6 +53,9 @@ parser.add_argument("--dp_norm_clip", type=float, default=1.0) parser.add_argument("--share_secrets_ratio", type=float, default=1.0) parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--reconstruct_secrets_threshold", type=int, default=3) +parser.add_argument("--client_password", type=str, default="") +parser.add_argument("--server_password", type=str, default="") +parser.add_argument("--enable_ssl", type=ast.literal_eval, default=False) args, _ = parser.parse_known_args() device_target = args.device_target @@ -80,6 +84,9 @@ dp_eps = args.dp_eps dp_delta = args.dp_delta dp_norm_clip = args.dp_norm_clip encrypt_type = args.encrypt_type +client_password = args.client_password +server_password = args.server_password +enable_ssl = args.enable_ssl ctx = { "enable_fl": True, @@ -107,7 +114,10 @@ ctx = { "dp_eps": dp_eps, "dp_delta": dp_delta, "dp_norm_clip": dp_norm_clip, - "encrypt_type": encrypt_type + "encrypt_type": encrypt_type, + "client_password": client_password, + "server_password": server_password, + "enable_ssl": enable_ssl } context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)