!29952 memory optimization for flclient in master

Merge pull request !29952 from zhoushan33/flclient0211_master
This commit is contained in:
i-robot 2022-02-14 04:14:12 +00:00 committed by Gitee
commit d6fd709963
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 199 additions and 61 deletions

View File

@ -16,6 +16,7 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.MAX_WAIT_TRY_TIME;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.SALT_SIZE;
@ -105,6 +106,7 @@ public class CipherClient {
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
private int retCode;
private Map<String, X509Certificate[]> certificateList = new HashMap<String, X509Certificate[]>();
private int waitTryTime = 0;
/**
* Construct function of cipherClient
@ -507,8 +509,7 @@ public class CipherClient {
LOGGER.info(Common.addTag("[requestExchangeKeys] the server is not ready now, need wait some time and" +
" " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
@ -589,8 +590,7 @@ public class CipherClient {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getExchangeKeys] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
@ -840,8 +840,7 @@ public class CipherClient {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
" and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
@ -920,8 +919,7 @@ public class CipherClient {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getShareSecrets] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
@ -993,7 +991,13 @@ public class CipherClient {
// RequestExchangeKeys
FLClientStatus curStatus;
curStatus = requestExchangeKeys();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = requestExchangeKeys();
}
@ -1003,7 +1007,13 @@ public class CipherClient {
// GetExchangeKeys
curStatus = getExchangeKeys();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = getExchangeKeys();
}
@ -1021,7 +1031,13 @@ public class CipherClient {
FLClientStatus curStatus;
// RequestShareSecrets
curStatus = requestShareSecrets();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = requestShareSecrets();
}
@ -1031,7 +1047,13 @@ public class CipherClient {
// GetShareSecrets
curStatus = getShareSecrets();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = getShareSecrets();
}
@ -1050,7 +1072,13 @@ public class CipherClient {
// GetClientList
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList,
cUVKeys);
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList
, cUVKeys);
@ -1067,7 +1095,13 @@ public class CipherClient {
if (flParameter.isPkiVerify()) {
LOGGER.info(Common.addTag("[PairWiseMask] The mode is pkiVerify mode, start clientList check ..."));
curStatus = clientListCheck();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = clientListCheck();
}
@ -1078,7 +1112,13 @@ public class CipherClient {
// SendReconstructSecret
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
}
@ -1149,7 +1189,13 @@ public class CipherClient {
// send signed clientList
curStatus = sendClientListSign();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = sendClientListSign();
}
@ -1160,7 +1206,13 @@ public class CipherClient {
// get signed clientList
curStatus = getAllClientListSign();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
Common.sleep(SLEEP_TIME);
curStatus = getAllClientListSign();
}
@ -1216,8 +1268,7 @@ public class CipherClient {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[sendClientListSign] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
@ -1287,8 +1338,7 @@ public class CipherClient {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getAllClientListSign] the server is not ready now, need wait some time " +
"and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
@ -1424,4 +1474,13 @@ public class CipherClient {
}
return pemCerts;
}
private Boolean ifWaitTryTimeExceedsLimit() {
if (waitTryTime > MAX_WAIT_TRY_TIME) {
LOGGER.severe(Common.addTag("[ifWaitTryTimeExceedsLimit] the waitTryTime exceeds the limit, current " +
"waitTryTime is: " + waitTryTime + " the limited time is: " + MAX_WAIT_TRY_TIME));
return true;
}
return false;
}
}

View File

@ -41,6 +41,9 @@ import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.mindspore.flclient.FLParameter.MAX_SLEEP_TIME;
import static com.mindspore.flclient.FLParameter.MAX_WAIT_TRY_TIME;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
@ -218,14 +221,21 @@ public class Common {
* @param millis the waiting time (ms).
*/
public static void sleep(long millis) {
if (millis > 0) {
try {
synchronized (STOP_OBJECT) {
STOP_OBJECT.wait(millis); // 1000 milliseconds is one second.
}
} catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
if (millis <= 0) {
LOGGER.severe(addTag("[sleep] the millis is not valid(<= 0), will not do any thing, and stop the task"));
throw new IllegalArgumentException();
}
if (millis > MAX_SLEEP_TIME) {
LOGGER.severe(addTag("[sleep] the sleep time: " + millis + " exceed MAX_SLEEP_TIME: " + MAX_SLEEP_TIME
+ "(unit: ms), will only sleep 30 minutes."));
millis = MAX_SLEEP_TIME;
}
try {
synchronized (STOP_OBJECT) {
STOP_OBJECT.wait(millis); // 1000 milliseconds is one second.
}
} catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
}
}
@ -249,7 +259,12 @@ public class Common {
long currentTime = date.getTime();
long waitTime = 0L;
if (!(nextRequestTime == null || nextRequestTime.isEmpty())) {
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
waitTime = Long.valueOf(nextRequestTime) - currentTime;
}
if (waitTime <= 0L) {
LOGGER.severe(addTag("[getWaitTime] waitTime: " + waitTime + " is not valid (should be > 0), the reasons " +
"may be: the nextRequestTime <= currentTime, or the nextRequestTime is null, will stop the task"));
throw new IllegalArgumentException();
}
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " +
currentTime));
@ -598,4 +613,23 @@ public class Common {
SessionUtil.free(trainLenet.getTrainSession());
}
}
/**
* Initialization session.
*
* @return the status code in client.
*/
public static String getNextReqTime() {
FLParameter flParameter = FLParameter.getInstance();
long millis;
if (flParameter.getSleepTime() != 0) {
millis = flParameter.getSleepTime();
} else {
millis = SLEEP_TIME;
}
Date curDate = new Date();
long currentTime = curDate.getTime();
String nextRequestTime = Long.toString(currentTime + millis);
return nextRequestTime;
}
}

View File

@ -268,8 +268,7 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[startFLJob] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return status;
}
@ -387,8 +386,7 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[updateModel] the server is not ready now, need wait some time and request" +
" again"));
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return status;
}

View File

@ -40,14 +40,30 @@ public class FLParameter {
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
/**
* The timeout interval for communication on the device.
* The timeout interval for communication on the device, time unit: seconds.
*/
public static final int TIME_OUT = 100;
/**
* The waiting time of repeated requests.
* The waiting time of repeated requests, time unit: milliseconds.
*/
public static final int SLEEP_TIME = 1000;
public static final int SLEEP_TIME = 10000;
/**
* The max waiting time when call sleeping, time unit: milliseconds.
*/
public static final int MAX_SLEEP_TIME = 1800000;
/**
* Maximum number of times to repeat RESTART
*/
public static final int RESTART_TIME_PER_ITER = 1;
/**
* Maximum number of times to wait some time and then repeat the same request.
*/
public static final int MAX_WAIT_TRY_TIME = 18;
private static volatile FLParameter flParameter;
private String deployEnv;

View File

@ -16,6 +16,8 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.MAX_WAIT_TRY_TIME;
import static com.mindspore.flclient.FLParameter.RESTART_TIME_PER_ITER;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.ANDROID;
@ -56,6 +58,9 @@ public class SyncFLJob {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private IFLJobResultCallback flJobResultCallback;
private FLClientStatus curStatus;
private int tryTimePerIter = 0;
private int lastIteration = -1;
private int waitTryTime = 0;
private void initFlIDForPkiVerify() {
if (flParameter.isPkiVerify()) {
@ -80,7 +85,8 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("the flName: " + flParameter.getFlName()));
Class.forName(flParameter.getFlName());
} catch (ClassNotFoundException e) {
LOGGER.severe(Common.addTag("catch ClassNotFoundException error, the set flName does not exist, please " +
LOGGER.severe(Common.addTag("catch ClassNotFoundException error, the set flName does not exist, " +
"please " +
"check: " + e.getMessage()));
throw new IllegalArgumentException();
}
@ -104,8 +110,18 @@ public class SyncFLJob {
FLLiteClient flLiteClient = new FLLiteClient();
LOGGER.info(Common.addTag("recovery StopJobFlag to false in the start of fl job"));
localFLParameter.setStopJobFlag(false);
InitialParameters();
LOGGER.info(Common.addTag("flJobRun start"));
flRunLoop(flLiteClient);
LOGGER.info(Common.addTag("flJobRun finish"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), flLiteClient.getIterations(),
flLiteClient.getRetCode());
return curStatus;
}
private void flRunLoop(FLLiteClient flLiteClient) {
do {
if (checkStopJobFlag()) {
if (ifTryTimeExceedsLimit() || checkStopJobFlag()) {
break;
}
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
@ -119,8 +135,9 @@ public class SyncFLJob {
flLiteClient.setTrainDataSize(trainDataSize);
// startFLJob
curStatus = startFLJob(flLiteClient);
curStatus = flLiteClient.startFLJob();
if (curStatus == FLClientStatus.RESTART) {
tryTimePerIter += 1;
restart("[startFLJob]", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
} else if (curStatus == FLClientStatus.FAILED) {
@ -128,6 +145,7 @@ public class SyncFLJob {
break;
}
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration()));
updateTryTimePerIter(flLiteClient);
// create mask
curStatus = flLiteClient.getFeatureMask();
@ -148,7 +166,7 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("[train] train succeed"));
// updateModel
curStatus = updateModel(flLiteClient);
curStatus = flLiteClient.updateModel();
if (curStatus == FLClientStatus.RESTART) {
restart("[updateModel]", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
@ -199,41 +217,54 @@ public class SyncFLJob {
flLiteClient.getRetCode());
Common.freeSession();
} while (flLiteClient.getIteration() < flLiteClient.getIterations());
LOGGER.info(Common.addTag("flJobRun finish"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), flLiteClient.getIterations(),
flLiteClient.getRetCode());
return curStatus;
}
private FLClientStatus startFLJob(FLLiteClient flLiteClient) {
FLClientStatus curStatus = flLiteClient.startFLJob();
while (curStatus == FLClientStatus.WAIT) {
if (checkStopJobFlag()) {
curStatus = FLClientStatus.FAILED;
break;
}
waitSomeTime();
curStatus = flLiteClient.startFLJob();
}
return curStatus;
private void InitialParameters() {
tryTimePerIter = 0;
lastIteration = -1;
waitTryTime = 0;
}
private FLClientStatus updateModel(FLLiteClient flLiteClient) {
FLClientStatus curStatus = flLiteClient.updateModel();
while (curStatus == FLClientStatus.WAIT) {
if (checkStopJobFlag()) {
curStatus = FLClientStatus.FAILED;
break;
}
waitSomeTime();
curStatus = flLiteClient.updateModel();
private Boolean ifTryTimeExceedsLimit() {
if (tryTimePerIter > RESTART_TIME_PER_ITER) {
LOGGER.severe(Common.addTag("[ifTryTimeExceedsLimit] the repeated request time exceeds the limit, current" +
" repeated" +
" request time is: " + tryTimePerIter + " the limited time is: " + RESTART_TIME_PER_ITER));
curStatus = FLClientStatus.FAILED;
return true;
}
return curStatus;
return false;
}
private void updateTryTimePerIter(FLLiteClient flLiteClient) {
if (lastIteration != -1 && lastIteration == flLiteClient.getIteration()) {
tryTimePerIter += 1;
} else {
tryTimePerIter = 1;
lastIteration = flLiteClient.getIteration();
}
}
private Boolean ifWaitTryTimeExceedsLimit() {
if (waitTryTime > MAX_WAIT_TRY_TIME) {
LOGGER.severe(Common.addTag("[ifWaitTryTimeExceedsLimit] the waitTryTime exceeds the limit, current " +
"waitTryTime is: " + waitTryTime + " the limited time is: " + MAX_WAIT_TRY_TIME));
curStatus = FLClientStatus.FAILED;
return true;
}
return false;
}
private FLClientStatus getModel(FLLiteClient flLiteClient) {
FLClientStatus curStatus = flLiteClient.getModel();
waitTryTime = 0;
while (curStatus == FLClientStatus.WAIT) {
waitTryTime += 1;
if (ifWaitTryTimeExceedsLimit()) {
curStatus = FLClientStatus.FAILED;
break;
}
if (checkStopJobFlag()) {
curStatus = FLClientStatus.FAILED;
break;
@ -241,6 +272,7 @@ public class SyncFLJob {
waitSomeTime();
curStatus = flLiteClient.getModel();
}
waitTryTime = 0;
return curStatus;
}
@ -303,7 +335,8 @@ public class SyncFLJob {
client.free();
return null;
}
Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(),
flParameter.getInputShape());
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[model inference] unsolved error code in <initSessionAndInputs>: the return " +
" status is: " + tag));

View File

@ -122,8 +122,7 @@ public class ClientListReq {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getClientList] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}

View File

@ -137,8 +137,7 @@ public class ReconstructSecretReq {
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[sendReconstructSecret] the server is not ready now, need wait some " +
"time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
nextRequestTime = Common.getNextReqTime();
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}