From b4791b9f8984c00d4ce28582fcfd089a3e947ebf Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Wed, 7 Jul 2021 09:15:40 +0800 Subject: [PATCH] Optimize test scripts of mindspore federated --- .../src/main/java/com/mindspore/flclient/FLLiteClient.java | 2 +- tests/st/fl/mobile/run_mobile_server.py | 2 +- tests/st/fl/mobile/test_mobile_lenet.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index 0c0ddc8ace2..f7291ebd2ba 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -95,7 +95,7 @@ public class FLLiteClient { LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + epochs)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + batchSize)); CipherPublicParams cipherPublicParams = flPlan.cipher(); - String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString(); + String encryptLevel = cipherPublicParams.encryptType(); localFLParameter.setEncryptLevel(encryptLevel); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + encryptLevel)); switch (localFLParameter.getEncryptLevel()) { diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py index 4caf9caa04b..2b5bdbce6b5 100644 --- a/tests/st/fl/mobile/run_mobile_server.py +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -38,7 +38,7 @@ parser.add_argument("--client_batch_size", type=int, default=32) parser.add_argument("--client_learning_rate", type=float, default=0.1) parser.add_argument("--local_server_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") -parser.add_argument("--encrypt_type", type=str, default="NotEncrypt") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") # parameters for encrypt_type='DP_ENCRYPT' parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py index b9d150c7ba3..345f52dd13b 100644 --- a/tests/st/fl/mobile/test_mobile_lenet.py +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -46,6 +46,7 @@ parser.add_argument("--client_batch_size", type=int, default=32) parser.add_argument("--client_learning_rate", type=float, default=0.1) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") # parameters for encrypt_type='DP_ENCRYPT' parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num