diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java index d089f90dadb..f069d2200b2 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java @@ -83,6 +83,7 @@ public class CipherClient { private Random random = new Random(); private ClientListReq clientListReq = new ClientListReq(); private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq(); + private int retCode; public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) { flCommunication = FLCommunication.getInstance(); @@ -110,6 +111,10 @@ public class CipherClient { return nextRequestTime; } + public int getRetCode() { + return retCode; + } + public void genDHKeyPairs() { byte[] csk = keyAgreement.generatePrivateKey(); byte[] cpk = keyAgreement.generatePublicKey(csk); @@ -282,13 +287,13 @@ public class CipherClient { } public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) { - int retcode = bufData.retcode(); + retCode = bufData.retcode(); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************")); - LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success")); return FLClientStatus.SUCCESS; @@ -301,7 +306,7 @@ public class CipherClient { LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ResponseExchangeKeys is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ResponseExchangeKeys is invalid: " + retCode)); return FLClientStatus.FAILED; } } @@ -329,12 +334,12 @@ public class CipherClient { } public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) { - int retcode = bufData.retcode(); + retCode = bufData.retcode(); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************")); - LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys success")); clientPublicKeyList.clear(); @@ -366,7 +371,7 @@ public class CipherClient { LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnExchangeKeys is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnExchangeKeys is invalid: " + retCode)); return FLClientStatus.FAILED; } } @@ -417,13 +422,13 @@ public class CipherClient { } public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) { - int retcode = bufData.retcode(); + retCode = bufData.retcode(); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************")); - LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success")); return FLClientStatus.SUCCESS; @@ -436,7 +441,7 @@ public class CipherClient { LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ResponseShareSecrets is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ResponseShareSecrets is invalid: " + retCode)); return FLClientStatus.FAILED; } } @@ -464,13 +469,13 @@ public class CipherClient { } public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) { - int retcode = bufData.retcode(); + retCode = bufData.retcode(); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************")); - LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); LOGGER.info(Common.addTag("[PairWiseMask] the size of encrypted shares: " + bufData.encryptedSharesLength())); - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets success")); returnShareList.clear(); @@ -499,7 +504,7 @@ public class CipherClient { LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnShareSecrets is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnShareSecrets is invalid: " + retCode)); return FLClientStatus.FAILED; } } @@ -563,6 +568,7 @@ public class CipherClient { if (curStatus != FLClientStatus.SUCCESS) { return curStatus; } + retCode = clientListReq.getRetCode(); // SendReconstructSecret curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration); @@ -573,6 +579,7 @@ public class CipherClient { if (curStatus == FLClientStatus.RESTART) { nextRequestTime = reconstructSecretReq.getNextRequestTime(); } + retCode = reconstructSecretReq.getRetCode(); return curStatus; } } \ No newline at end of file diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java index e585cc8b67b..474f6131421 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java @@ -88,7 +88,7 @@ public class Common { Date date = new Date(); long currentTime = date.getTime(); long waitTime = 0; - if (!nextRequestTime.equals("")) { + if (!("").equals(nextRequestTime)) { waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime); } LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime)); @@ -114,12 +114,20 @@ public class Common { return LOG_TITLE + message; } - public static boolean isAutoscaling(byte[] message, String autoscalingTag) { - return (new String(message)).contains(autoscalingTag); + public static boolean isSafeMod(byte[] message, String safeModTag) { + return (new String(message)).contains(safeModTag); } public static boolean checkPath(String path) { - File file = new File(path); - return file.exists(); + boolean tag = true; + String [] paths = path.split(","); + for (int i = 0; i < paths.length; i++) { + LOGGER.info(addTag("[check path]:" + paths[i])); + File file = new File(paths[i]); + if (!file.exists()) { + tag = false; + } + } + return tag; } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java index 37e97c3dc4a..8aa65add4d1 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java @@ -48,7 +48,7 @@ public class FLCommunication implements IFLCommunication { private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString()); private OkHttpClient client; - private static FLCommunication communication; + private static volatile FLCommunication communication; private FLCommunication() { if (flParameter.getTimeOut() != 0) { @@ -109,14 +109,16 @@ public class FLCommunication implements IFLCommunication { } public static FLCommunication getInstance() { - if (communication == null) { + FLCommunication localRef = communication; + if (localRef == null) { synchronized (FLCommunication.class) { - if (communication == null) { - communication = new FLCommunication(); + localRef = communication; + if (localRef == null) { + communication = localRef = new FLCommunication(); } } } - return communication; + return localRef; } @Override diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java new file mode 100644 index 00000000000..e878af23d31 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -0,0 +1,535 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient; + +import com.mindspore.flclient.cipher.BaseUtil; +import com.mindspore.flclient.model.AdInferBert; +import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.TrainLenet; +import mindspore.schema.CipherPublicParams; +import mindspore.schema.FLPlan; +import mindspore.schema.ResponseCode; +import mindspore.schema.ResponseFLJob; +import mindspore.schema.ResponseGetModel; +import mindspore.schema.ResponseUpdateModel; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; +import java.util.logging.Logger; + +import static com.mindspore.flclient.FLParameter.SLEEP_TIME; +import static com.mindspore.flclient.LocalFLParameter.ADBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + +public class FLLiteClient { + private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString()); + private FLCommunication flCommunication; + + private FLClientStatus status; + private int retCode; + + private static int iteration = 0; + private int iterations = 1; + private int epochs = 1; + private int batchSize = 16; + private int minSecretNum; + private byte[] prime; + private int featureSize; + private int trainDataSize; + private double dpEps = 100; + private double dpDelta = 0.01; + private double dpNormClip = 2.0; + + private FLParameter flParameter = FLParameter.getInstance(); + private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private SecureProtocol secureProtocol = new SecureProtocol(); + private static Map mapBeforeTrain; + private String nextRequestTime; + + public FLLiteClient() { + flCommunication = FLCommunication.getInstance(); + } + + public int setGlobalParameters(ResponseFLJob flJob) { + FLPlan flPlan = flJob.flPlanConfig(); + if (flPlan == null) { + LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null")); + return -1; + } + iterations = flPlan.iterations(); + epochs = flPlan.epochs(); + batchSize = flPlan.miniBatch(); + String serverMod = flPlan.serverMode(); + localFLParameter.setServerMod(serverMod); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + serverMod)); + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] set for AdTrainBert: " + batchSize)); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + adTrainBert.setBatchSize(batchSize); + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] set for TrainLenet: " + batchSize)); + TrainLenet trainLenet = TrainLenet.getInstance(); + trainLenet.setBatchSize(batchSize); + } + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + iterations)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + epochs)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + batchSize)); + CipherPublicParams cipherPublicParams = flPlan.cipher(); + String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString(); + localFLParameter.setEncryptLevel(encryptLevel); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + encryptLevel)); + switch (localFLParameter.getEncryptLevel()) { + case PW_ENCRYPT: + minSecretNum = cipherPublicParams.t(); + int primeLength = cipherPublicParams.primeLength(); + prime = new byte[primeLength]; + for (int i = 0; i < primeLength; i++) { + prime[i] = (byte) cipherPublicParams.prime(i); + } + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + minSecretNum)); + LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime))); + case DP_ENCRYPT: + dpEps = cipherPublicParams.dpEps(); + dpDelta = cipherPublicParams.dpDelta(); + dpNormClip = cipherPublicParams.dpNormClip(); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + dpEps)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + dpDelta)); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + dpNormClip)); + break; + default: + LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt")); + } + return 0; + } + + public int getRetCode() { + return retCode; + } + + public int getIteration() { + return iteration; + } + + public int getIterations() { + return iterations; + } + + public int getEpochs() { + return epochs; + } + + public int getBatchSize() { + return batchSize; + } + + public String getNextRequestTime() { + return nextRequestTime; + } + + public void setNextRequestTime(String nextRequestTime) { + this.nextRequestTime = nextRequestTime; + } + + public void setTrainDataSize(int trainDataSize) { + this.trainDataSize = trainDataSize; + } + + public FLClientStatus checkStatus() { + return this.status; + } + + public FLClientStatus startFLJob() { + LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server====================================")); + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); + LOGGER.info(Common.addTag("[startFLJob] ==============startFLJob url: " + url + "==============")); + StartFLJob startFLJob = StartFLJob.getInstance(); + Date date = new Date(); + long time = date.getTime(); + byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time); + try { + long start = Common.startTime("single startFLJob"); + LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length)); + byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg); + if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { + LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again")); + status = FLClientStatus.RESTART; + Common.sleep(SLEEP_TIME); + nextRequestTime = ""; + return status; + } + LOGGER.info(Common.addTag("[startFLJob] the response message length: " + message.length)); + Common.endTime(start, "single startFLJob"); + ByteBuffer buffer = ByteBuffer.wrap(message); + ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer); + status = judgeStartFLJob(startFLJob, responseDataBuf); + retCode = responseDataBuf.retcode(); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage())); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + } + return status; + } + + public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) { + iteration = responseDataBuf.iteration(); + FLClientStatus response = startFLJob.doResponse(responseDataBuf); + status = response; + switch (response) { + case SUCCESS: + LOGGER.info(Common.addTag("[startFLJob] startFLJob success")); + featureSize = startFLJob.getFeatureSize(); + secureProtocol.setEncryptFeatureName(startFLJob.getEncryptFeatureName()); + LOGGER.info(Common.addTag("[startFLJob] ***the feature size get in ResponseFLJob***: " + featureSize)); + int tag = setGlobalParameters(responseDataBuf); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] setGlobalParameters failed")); + status = FLClientStatus.FAILED; + } + break; + case RESTART: + FLPlan flPlan = responseDataBuf.flPlanConfig(); + iterations = flPlan.iterations(); + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + iterations)); + nextRequestTime = responseDataBuf.nextReqTime(); + break; + case FAILED: + LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed")); + break; + default: + LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range ")); + status = FLClientStatus.FAILED; + } + return status; + } + + public FLClientStatus localTrain() { + LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "====================================")); + status = FLClientStatus.SUCCESS; + retCode = ResponseCode.SUCCEED; + if (flParameter.getFlName().equals(ADBERT)) { + LOGGER.info(Common.addTag("[train] train in adbert")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + int tag = adTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); + if (tag == -1) { + LOGGER.severe(Common.addTag("[train] unsolved error code in ")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + } + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("[train] train in lenet")); + TrainLenet trainLenet = TrainLenet.getInstance(); + int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs); + if (tag == -1) { + LOGGER.severe(Common.addTag("[train] unsolved error code in ")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + } + } + return status; + } + + public FLClientStatus updateModel() { + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); + LOGGER.info(Common.addTag("[updateModel] ==============updateModel url: " + url + "==============")); + UpdateModel updateModelBuf = UpdateModel.getInstance(); + byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize); + if (updateModelBuf.getStatus() == FLClientStatus.FAILED) { + LOGGER.info(Common.addTag("[updateModel] catch error in build RequestUpdateFLJob")); + return FLClientStatus.FAILED; + } + try { + long start = Common.startTime("single updateModel"); + LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length)); + byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer); + if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { + LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again")); + status = FLClientStatus.RESTART; + Common.sleep(SLEEP_TIME); + nextRequestTime = ""; + return status; + } + LOGGER.info(Common.addTag("[updateModel] the response message length: " + message.length)); + Common.endTime(start, "single updateModel"); + ByteBuffer debugBuffer = ByteBuffer.wrap(message); + ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer); + status = updateModelBuf.doResponse(responseDataBuf); + retCode = responseDataBuf.retcode(); + if (status == FLClientStatus.RESTART) { + nextRequestTime = responseDataBuf.nextReqTime(); + } + LOGGER.info(Common.addTag("[updateModel] get response from server ok!")); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage())); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + } + return status; + } + + public FLClientStatus getModel() { + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); + LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "==============")); + GetModel getModelBuf = GetModel.getInstance(); + byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration); + try { + long start = Common.startTime("single getModel"); + LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length)); + byte[] message = flCommunication.syncRequest(url + "/getModel", buffer); + if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { + LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again")); + status = FLClientStatus.WAIT; + return status; + } + LOGGER.info(Common.addTag("[getModel] the response message length: " + message.length)); + Common.endTime(start, "single getModel"); + LOGGER.info(Common.addTag("[getModel] get model request success")); + ByteBuffer debugBuffer = ByteBuffer.wrap(message); + ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer); + status = getModelBuf.doResponse(responseDataBuf); + retCode = responseDataBuf.retcode(); + if (status == FLClientStatus.RESTART) { + nextRequestTime = responseDataBuf.timestamp(); + } + LOGGER.info(Common.addTag("[getModel] get response from server ok!")); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[getModel] un sloved error code: catch IOException: " + e.getMessage())); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + } + return status; + } + + public static synchronized Map getOldMapCopy(Map map) { + if (mapBeforeTrain == null) { + Map copyMap = new TreeMap<>(); + for (String key : map.keySet()) { + float[] data = map.get(key); + int dataLen = data.length; + float[] weights = new float[dataLen]; + if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) { + for (int j = 0; j < dataLen; j++) { + float weight = data[j]; + weights[j] = weight; + } + copyMap.put(key, weights); + } + } + mapBeforeTrain = copyMap; + } else { + for (String key : map.keySet()) { + float[] data = map.get(key); + float[] copyData = mapBeforeTrain.get(key); + int dataLen = data.length; + if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) { + for (int j = 0; j < dataLen; j++) { + copyData[j] = data[j]; + } + } + } + } + return mapBeforeTrain; + } + + public FLClientStatus getFeatureMask() { + FLClientStatus curStatus; + switch (localFLParameter.getEncryptLevel()) { + case PW_ENCRYPT: + LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">")); + secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize); + curStatus = secureProtocol.pwCreateMask(); + if (curStatus == FLClientStatus.RESTART) { + nextRequestTime = secureProtocol.getNextRequestTime(); + } + retCode = secureProtocol.getRetCode(); + LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus)); + return curStatus; + case DP_ENCRYPT: + Map map = new HashMap(); + if (flParameter.getFlName().equals(ADBERT)) { + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + } else if (flParameter.getFlName().equals(LENET)) { + TrainLenet trainLenet = TrainLenet.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); + } + Map copyMap = getOldMapCopy(map); + curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClip, copyMap); + retCode = ResponseCode.SUCCEED; + if (curStatus != FLClientStatus.SUCCESS) { + LOGGER.info(Common.addTag("---Differential privacy init failed---")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!")); + return FLClientStatus.SUCCESS; + case NOT_ENCRYPT: + retCode = ResponseCode.SUCCEED; + LOGGER.info(Common.addTag("[Encrypt] don't mask model")); + return FLClientStatus.SUCCESS; + default: + retCode = ResponseCode.SUCCEED; + LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default")); + return FLClientStatus.SUCCESS; + } + } + + public FLClientStatus unMasking() { + FLClientStatus curStatus; + switch (localFLParameter.getEncryptLevel()) { + case PW_ENCRYPT: + curStatus = secureProtocol.pwUnmasking(); + retCode = secureProtocol.getRetCode(); + LOGGER.info(Common.addTag("[Encrypt] the response of unmasking : " + curStatus)); + if (curStatus == FLClientStatus.RESTART) { + nextRequestTime = secureProtocol.getNextRequestTime(); + } + return curStatus; + case DP_ENCRYPT: + LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking")); + retCode = ResponseCode.SUCCEED; + return FLClientStatus.SUCCESS; + case NOT_ENCRYPT: + LOGGER.info(Common.addTag("[Encrypt] haven't mask model")); + retCode = ResponseCode.SUCCEED; + return FLClientStatus.SUCCESS; + default: + LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default")); + retCode = ResponseCode.SUCCEED; + return FLClientStatus.SUCCESS; + } + } + + public FLClientStatus evaluateModel() { + status = FLClientStatus.SUCCESS; + retCode = ResponseCode.SUCCEED; + LOGGER.info(Common.addTag("===================================test combine model from server===================================")); + if (flParameter.getFlName().equals(ADBERT)) { + AdInferBert adInferBert = AdInferBert.getInstance(); + int dataSize = adInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true); + if (dataSize <= 0) { + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + return status; + } + float acc = adInferBert.evalModel(); + if (acc == Float.NaN) { + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + return status; + } + LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile())); + LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); + } else if (flParameter.getFlName().equals(LENET)) { + TrainLenet trainLenet = TrainLenet.getInstance(); + int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]); + if (dataSize <= 0) { + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + return status; + } + float acc = trainLenet.evalModel(); + if (acc == Float.NaN) { + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); + status = FLClientStatus.FAILED; + retCode = ResponseCode.RequestError; + return status; + } + LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1])); + LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); + } + return status; + } + + /** + * @param dataPath, train or test dataset and label set + */ + public int setInput(String dataPath) { + retCode = ResponseCode.SUCCEED; + LOGGER.info(Common.addTag("==========set input===========")); + int dataSize = 0; + if (flParameter.getFlName().equals(ADBERT)) { + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + dataSize = adTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile()); + LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile())); + } else if (flParameter.getFlName().equals(LENET)) { + TrainLenet trainLenet = TrainLenet.getInstance(); + dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]); + LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1])); + } + if (dataSize <= 0) { + retCode = ResponseCode.RequestError; + return -1; + } + return dataSize; + } + + public FLClientStatus initSession() { + int tag = 0; + retCode = ResponseCode.SUCCEED; + if (flParameter.getFlName().equals(ADBERT)) { + LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); + if (tag == -1) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session=============")); + AdInferBert adInferBert = AdInferBert.getInstance(); + tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); + TrainLenet trainLenet = TrainLenet.getInstance(); + tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); + } + if (tag == -1) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + @Override + protected void finalize() { + if (flParameter.getFlName().equals(ADBERT)) { + LOGGER.info(Common.addTag("===========free train session=============")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + SessionUtil.free(adTrainBert.getTrainSession()); + if (!flParameter.getTestDataset().equals("null")) { + LOGGER.info(Common.addTag("===========free inference session=============")); + AdInferBert adInferBert = AdInferBert.getInstance(); + SessionUtil.free(adInferBert.getTrainSession()); + } + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("===========free session=============")); + TrainLenet trainLenet = TrainLenet.getInstance(); + SessionUtil.free(trainLenet.getTrainSession()); + } + } + +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java new file mode 100644 index 00000000000..668dadad39a --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java @@ -0,0 +1,307 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient; + +import java.util.logging.Logger; + +public class FLParameter { + private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString()); + + public static final int TIME_OUT = 100; + public static final int SLEEP_TIME = 1000; + private String hostName; + private String certPath; + + private String trainDataset; + private String vocabFile = "null"; + private String idsFile = "null"; + private String testDataset = "null"; + private String flName; + private String trainModelPath; + private String inferModelPath; + private String clientID; + private String ip; + private int port; + private boolean useSSL = false; + private int timeOut; + private int sleepTime; + private boolean useElb = false; + private int serverNum = 1; + + private boolean timer = true; + private int timeWindow = 6000; + private int reRequestNum = timeWindow / SLEEP_TIME + 1; + + private static volatile FLParameter flParameter; + + public static FLParameter getInstance() { + FLParameter localRef = flParameter; + if (localRef == null) { + synchronized (FLParameter.class) { + localRef = flParameter; + if (localRef == null) { + flParameter = localRef = new FLParameter(); + } + } + } + return localRef; + } + + public String getHostName() { + if ("".equals(hostName) || hostName.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return hostName; + } + + public void setHostName(String hostName) { + this.hostName = hostName; + } + + public String getCertPath() { + if ("".equals(certPath) || certPath.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return certPath; + } + + public void setCertPath(String certPath) { + this.certPath = certPath; + } + + public String getTrainDataset() { + if ("".equals(trainDataset) || trainDataset.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return trainDataset; + } + + public void setTrainDataset(String trainDataset) { + if (Common.checkPath(trainDataset)) { + this.trainDataset = trainDataset; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it before set")); + throw new RuntimeException(); + } + } + + public String getVocabFile() { + if ("null".equals(vocabFile) && "adbert".equals(flName)) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return vocabFile; + } + + public void setVocabFile(String vocabFile) { + if (Common.checkPath(vocabFile)) { + this.vocabFile = vocabFile; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it before set")); + throw new RuntimeException(); + } + } + + public String getIdsFile() { + if ("null".equals(idsFile) && "adbert".equals(flName)) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return idsFile; + } + + public void setIdsFile(String idsFile) { + if (Common.checkPath(idsFile)) { + this.idsFile = idsFile; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it before set")); + throw new RuntimeException(); + } + } + + public String getTestDataset() { + return testDataset; + } + + public void setTestDataset(String testDataset) { + if (Common.checkPath(testDataset)) { + this.testDataset = testDataset; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it before set")); + throw new RuntimeException(); + } + } + + public String getFlName() { + if ("".equals(flName) || flName.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return flName; + } + + public void setFlName(String flName) { + if (Common.checkFLName(flName)) { + this.flName = flName; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not in flNameTrustList, please check it before set")); + throw new RuntimeException(); + } + } + + public String getTrainModelPath() { + if ("".equals(trainModelPath) || trainModelPath.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return trainModelPath; + } + + public void setTrainModelPath(String trainModelPath) { + if (Common.checkPath(trainModelPath)) { + this.trainModelPath = trainModelPath; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it before set")); + throw new RuntimeException(); + } + } + + public String getInferModelPath() { + if ("".equals(inferModelPath) || inferModelPath.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return inferModelPath; + } + + public void setInferModelPath(String inferModelPath) { + if (Common.checkPath(inferModelPath)) { + this.inferModelPath = inferModelPath; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it before set")); + throw new RuntimeException(); + } + } + + public String getIp() { + if ("".equals(ip) || ip.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return ip; + } + + public void setIp(String ip) { + this.ip = ip; + } + + public boolean isUseSSL() { + return useSSL; + } + + public void setUseSSL(boolean useSSL) { + this.useSSL = useSSL; + } + + public int getPort() { + if (port == 0) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return port; + } + + public void setPort(int port) { + this.port = port; + } + + + public int getTimeOut() { + return timeOut; + } + + public void setTimeOut(int timeOut) { + this.timeOut = timeOut; + } + + public int getSleepTime() { + return sleepTime; + } + + public void setSleepTime(int sleepTime) { + this.sleepTime = sleepTime; + } + + public boolean isUseElb() { + return useElb; + } + + public void setUseElb(boolean useElb) { + this.useElb = useElb; + } + + public int getServerNum() { + if (serverNum == 0) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is zero, please set it before use")); + throw new RuntimeException(); + } + return serverNum; + } + + public void setServerNum(int serverNum) { + this.serverNum = serverNum; + } + + public boolean isTimer() { + return timer; + } + + public void setTimer(boolean timer) { + this.timer = timer; + } + + public int getTimeWindow() { + return timeWindow; + } + + public void setTimeWindow(int timeWindow) { + this.timeWindow = timeWindow; + } + + public int getReRequestNum() { + return reRequestNum; + } + + public void setReRequestNum(int reRequestNum) { + this.reRequestNum = reRequestNum; + } + + public String getClientID() { + if ("".equals(clientID) || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return clientID; + } + + public void setClientID(String clientID) { + this.clientID = clientID; + } + +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java new file mode 100644 index 00000000000..9595f9d0308 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java @@ -0,0 +1,184 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient; + +import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.model.AdInferBert; +import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.TrainLenet; +import mindspore.schema.FeatureMap; +import mindspore.schema.RequestGetModel; +import mindspore.schema.ResponseCode; +import mindspore.schema.ResponseGetModel; + +import java.util.ArrayList; +import java.util.Date; +import java.util.logging.Logger; + +public class GetModel { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + class RequestGetModelBuilder { + private FlatBufferBuilder builder; + private int nameOffset = 0; + private int iteration = 0; + private int timeStampOffset = 0; + + public RequestGetModelBuilder() { + builder = new FlatBufferBuilder(); + } + + public RequestGetModelBuilder flName(String name) { + this.nameOffset = this.builder.createString(name); + return this; + } + + public RequestGetModelBuilder time() { + Date date = new Date(); + long time = date.getTime(); + this.timeStampOffset = builder.createString(String.valueOf(time)); + return this; + } + + public RequestGetModelBuilder iteration(int iteration) { + this.iteration = iteration; + return this; + } + + public byte[] build() { + int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset); + builder.finish(root); + return builder.sizedByteArray(); + } + } + + private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString()); + private static GetModel getModel; + + private GetModel() { + } + + private FLParameter flParameter = FLParameter.getInstance(); + private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + + public static GetModel getInstance() { + if (getModel == null) { + getModel = new GetModel(); + } + return getModel; + } + + public byte[] getRequestGetModel(String name, int iteration) { + RequestGetModelBuilder builder = new RequestGetModelBuilder(); + return builder.iteration(iteration).flName(name).time().build(); + } + + + private FLClientStatus parseResponseAdbert(ResponseGetModel responseDataBuf) { + int fmCount = responseDataBuf.featureMapLength(); + ArrayList albertFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = responseDataBuf.featureMap(i); + String featureName = feature.weightFullname(); + if (localFLParameter.getAlbertWeightName().contains(featureName)) { + albertFeatureMaps.add(feature); + inferFeatureMaps.add(feature); + } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { + inferFeatureMaps.add(feature); + } else { + continue; + } + LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------")); + AdInferBert adInferBert = AdInferBert.getInstance(); + tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + private FLClientStatus parseResponseLenet(ResponseGetModel responseDataBuf) { + int fmCount = responseDataBuf.featureMapLength(); + ArrayList featureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = responseDataBuf.featureMap(i); + String featureName = feature.weightFullname(); + featureMaps.add(feature); + LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); + TrainLenet trainLenet = TrainLenet.getInstance(); + tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + + public FLClientStatus doResponse(ResponseGetModel responseDataBuf) { + LOGGER.info(Common.addTag("[getModel] ==========get model content is:================")); + LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode())); + LOGGER.info(Common.addTag("[getModel] ==========reason: " + responseDataBuf.reason())); + LOGGER.info(Common.addTag("[getModel] ==========iteration: " + responseDataBuf.iteration())); + LOGGER.info(Common.addTag("[getModel] ==========time: " + responseDataBuf.timestamp())); + FLClientStatus status = FLClientStatus.SUCCESS; + int retCode = responseDataBuf.retcode(); + switch (retCode) { + case (ResponseCode.SUCCEED): + LOGGER.info(Common.addTag("[getModel] getModel response success")); + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[getModel] into ")); + status = parseResponseAdbert(responseDataBuf); + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[getModel] into ")); + status = parseResponseLenet(responseDataBuf); + } + return status; + case (ResponseCode.SucNotReady): + LOGGER.info(Common.addTag("[getModel] server is not ready now: need wait and request getModel again")); + return FLClientStatus.WAIT; + case (ResponseCode.OutOfTime): + LOGGER.info(Common.addTag("[getModel] out of time: need wait and request startFLJob again")); + return FLClientStatus.RESTART; + case (ResponseCode.RequestError): + case (ResponseCode.SystemError): + LOGGER.warning(Common.addTag("[getModel] catch RequestError or SystemError")); + return FLClientStatus.FAILED; + default: + LOGGER.severe(Common.addTag("[getModel] the return from server is invalid: " + retCode)); + return FLClientStatus.FAILED; + } + } + +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java new file mode 100644 index 00000000000..fdc32429678 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java @@ -0,0 +1,125 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Logger; + +public class LocalFLParameter { + private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString()); + public static final int SEED_SIZE = 32; + public static final int IVEC_LEN = 16; + public static final String LENET = "lenet"; + public static final String ADBERT = "adbert"; + private List classifierWeightName = new ArrayList<>(); + private List albertWeightName = new ArrayList<>(); + + private String flID; + private String encryptLevel = "NotEncrypt"; + private String earlyStopMod = "NotEarlyStop"; + private String serverMod = ServerMod.HYBRID_TRAINING.toString(); + private String safeMod = "The cluster is in safemode."; + + private static volatile LocalFLParameter localFLParameter; + + private LocalFLParameter() { + // set classifierWeightName albertWeightName + Common.setClassifierWeightName(classifierWeightName); + Common.setAlbertWeightName(albertWeightName); + } + + public static synchronized LocalFLParameter getInstance() { + LocalFLParameter localRef = localFLParameter; + if (localRef == null) { + synchronized (LocalFLParameter.class) { + localRef = localFLParameter; + if (localRef == null) { + localFLParameter = localRef = new LocalFLParameter(); + } + } + } + return localRef; + } + + public List getClassifierWeightName() { + if (classifierWeightName.isEmpty()) { + LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return classifierWeightName; + } + + public void setClassifierWeightName(List classifierWeightName) { + this.classifierWeightName = classifierWeightName; + } + + public List getAlbertWeightName() { + if (albertWeightName.isEmpty()) { + LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return albertWeightName; + } + + public void setAlbertWeightName(List albertWeightName) { + this.albertWeightName = albertWeightName; + } + + public String getFlID() { + if ("".equals(flID) || flID == null) { + LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is null, please set it before use")); + throw new RuntimeException(); + } + return flID; + } + + public void setFlID(String flID) { + this.flID = flID; + } + + public EncryptLevel getEncryptLevel() { + return EncryptLevel.valueOf(encryptLevel); + } + + public void setEncryptLevel(String encryptLevel) { + this.encryptLevel = encryptLevel; + } + + public EarlyStopMod getEarlyStopMod() { + return EarlyStopMod.valueOf(earlyStopMod); + } + + public void setEarlyStopMod(String earlyStopMod) { + this.earlyStopMod = earlyStopMod; + } + + public String getServerMod() { + return serverMod; + } + + public void setServerMod(String serverMod) { + this.serverMod = serverMod; + } + + public String getSafeMod() { + return safeMod; + } + + public void setSafeMod(String safeMod) { + this.safeMod = safeMod; + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java index 9a270d38957..a1c668463ce 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java @@ -42,6 +42,7 @@ public class SecureProtocol { private static double deltaError = 1e-6; private static Map modelMap; private ArrayList encryptFeatureName = new ArrayList(); + private int retCode; public FLClientStatus getStatus() { return status; @@ -51,6 +52,10 @@ public class SecureProtocol { return featureMask; } + public int getRetCode() { + return retCode; + } + public SecureProtocol() { } @@ -91,6 +96,7 @@ public class SecureProtocol { LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="); // round 0 status = cipher.exchangeKeys(); + retCode = cipher.getRetCode(); LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============"); if (status != FLClientStatus.SUCCESS) { return status; @@ -98,6 +104,7 @@ public class SecureProtocol { // round 1 try { status = cipher.shareSecrets(); + retCode = cipher.getRetCode(); LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "============="); } catch (Exception e) { LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask"); @@ -109,6 +116,7 @@ public class SecureProtocol { // round2 try { featureMask = cipher.doubleMaskingWeight(); + retCode = cipher.getRetCode(); LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS============="); } catch (Exception e) { LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask"); @@ -151,6 +159,7 @@ public class SecureProtocol { public FLClientStatus pwUnmasking() { status = cipher.reconstructSecrets(); // round3 + retCode = cipher.getRetCode(); LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "============="); return status; } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java new file mode 100644 index 00000000000..232259d6fe5 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java @@ -0,0 +1,230 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient; + +import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.model.AdInferBert; +import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.TrainLenet; +import mindspore.schema.FeatureMap; +import mindspore.schema.RequestFLJob; +import mindspore.schema.ResponseCode; +import mindspore.schema.ResponseFLJob; + +import java.util.ArrayList; +import java.util.logging.Logger; + +public class StartFLJob { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString()); + + class RequestStartFLJobBuilder { + private RequestFLJob requestFLJob; + private FlatBufferBuilder builder; + private int nameOffset = 0; + private int iteration = 0; + private int dataSize = 0; + private int timestampOffset = 0; + private int idOffset = 0; + + public RequestStartFLJobBuilder() { + builder = new FlatBufferBuilder(); + } + + public RequestStartFLJobBuilder flName(String name) { + this.nameOffset = this.builder.createString(name); + return this; + } + + public RequestStartFLJobBuilder id(String id) { + this.idOffset = this.builder.createString(id); + return this; + } + + public RequestStartFLJobBuilder time(long timestamp) { + this.timestampOffset = builder.createString(String.valueOf(timestamp)); + return this; + } + + public RequestStartFLJobBuilder dataSize(int dataSize) { + // temp code need confirm + this.dataSize = dataSize; + LOGGER.info(Common.addTag("[startFLJob] the train data size: " + dataSize)); + return this; + } + + public RequestStartFLJobBuilder iteration(int iteration) { + this.iteration = iteration; + return this; + } + + public byte[] build() { + int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration, + this.dataSize, this.timestampOffset); + builder.finish(root); + return builder.sizedByteArray(); + } + } + + private static StartFLJob startFLJob; + + private FLClientStatus status; + + private FLParameter flParameter = FLParameter.getInstance(); + private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private int featureSize; + private String nextRequestTime; + private ArrayList encryptFeatureName = new ArrayList(); + + private StartFLJob() { + + } + + public static StartFLJob getInstance() { + if (startFLJob == null) { + startFLJob = new StartFLJob(); + } + return startFLJob; + } + + public String getNextRequestTime() { + return nextRequestTime; + } + + public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) { + RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder(); + return builder.flName(flParameter.getFlName()) + .time(time) + .id(localFLParameter.getFlID()) + .dataSize(dataSize) + .iteration(iteration) + .build(); + } + + public int getFeatureSize() { + return featureSize; + } + + public ArrayList getEncryptFeatureName() { + return encryptFeatureName; + } + + private FLClientStatus parseResponseAdbert(ResponseFLJob flJob) { + int fmCount = flJob.featureMapLength(); + ArrayList albertFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + encryptFeatureName.clear(); + if (fmCount <= 0) { + LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); + return FLClientStatus.FAILED; + } + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + String featureName = feature.weightFullname(); + if (localFLParameter.getAlbertWeightName().contains(featureName)) { + albertFeatureMaps.add(feature); + inferFeatureMaps.add(feature); + featureSize += feature.dataLength(); + encryptFeatureName.add(feature.weightFullname()); + } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { + inferFeatureMaps.add(feature); + } else { + continue; + } + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------")); + AdInferBert adInferBert = AdInferBert.getInstance(); + tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + private FLClientStatus parseResponseLenet(ResponseFLJob flJob) { + int fmCount = flJob.featureMapLength(); + ArrayList featureMaps = new ArrayList(); + encryptFeatureName.clear(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + String featureName = feature.weightFullname(); + featureMaps.add(feature); + featureSize += feature.dataLength(); + encryptFeatureName.add(featureName); + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); + TrainLenet trainLenet = TrainLenet.getInstance(); + tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + public FLClientStatus doResponse(ResponseFLJob flJob) { + LOGGER.info(Common.addTag("[startFLJob] return code: " + flJob.retcode())); + LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason())); + LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration())); + LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected())); + LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime())); + nextRequestTime = flJob.nextReqTime(); + LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp())); + int retcode = flJob.retcode(); + + switch (retcode) { + case (ResponseCode.SUCCEED): + localFLParameter.setServerMod(flJob.flPlanConfig().serverMode()); + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] into ")); + parseResponseAdbert(flJob); + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] into ")); + parseResponseLenet(flJob); + } + return FLClientStatus.SUCCESS; + case (ResponseCode.OutOfTime): + return FLClientStatus.RESTART; + case (ResponseCode.RequestError): + case (ResponseCode.SystemError): + LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError")); + return FLClientStatus.FAILED; + default: + LOGGER.severe(Common.addTag("[startFLJob] the return from server is invalid: " + retcode)); + return FLClientStatus.FAILED; + } + } + + public FLClientStatus getStatus() { + return this.status; + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java new file mode 100644 index 00000000000..cb91b0a3012 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -0,0 +1,322 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient; + +import com.mindspore.flclient.model.AdInferBert; +import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.TrainLenet; +import mindspore.schema.ResponseGetModel; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.logging.Logger; + +import static com.mindspore.flclient.FLParameter.SLEEP_TIME; +import static com.mindspore.flclient.LocalFLParameter.ADBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + +public class SyncFLJob { + private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString()); + private FLParameter flParameter = FLParameter.getInstance(); + private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private FLJobResultCallback flJobResultCallback = new FLJobResultCallback(); + + public SyncFLJob() { + } + + public void flJobRun() { + localFLParameter.setFlID(flParameter.getClientID()); + FLLiteClient client = new FLLiteClient(); + client.initSession(); + + FLClientStatus curStatus; + do { + LOGGER.info(Common.addTag("flName: " + flParameter.getFlName())); + int trainDataSize = client.setInput(flParameter.getTrainDataset()); + LOGGER.info(Common.addTag("train path: " + flParameter.getTrainDataset())); + LOGGER.info(Common.addTag("train data size: " + trainDataSize)); + if (trainDataSize <= 0) { + LOGGER.severe(Common.addTag("unsolved error code in : the return trainDataSize<=0")); + curStatus = FLClientStatus.FAILED; + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode()); + break; + } + client.setTrainDataSize(trainDataSize); + + // startFLJob + curStatus = client.startFLJob(); + while (curStatus == FLClientStatus.WAIT) { + waitSomeTime(); + curStatus = client.startFLJob(); + } + if (curStatus == FLClientStatus.RESTART) { + restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + continue; + } else if (curStatus == FLClientStatus.FAILED) { + failed("[startFLJob]", client.getIteration(), client.getRetCode(), curStatus); + break; + } + LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration())); + + // create mask + curStatus = client.getFeatureMask(); + if (curStatus == FLClientStatus.RESTART) { + restart("[Encrypt] creatMask", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + continue; + } else if (curStatus == FLClientStatus.FAILED) { + failed("[Encrypt] createMask", client.getIteration(), client.getRetCode(), curStatus); + break; + } + + // train + curStatus = client.localTrain(); + if (curStatus == FLClientStatus.FAILED) { + failed("[train] train", client.getIteration(), client.getRetCode(), curStatus); + break; + } + LOGGER.info(Common.addTag("[train] train succeed")); + + // updateModel + curStatus = client.updateModel(); + while (curStatus == FLClientStatus.WAIT) { + waitSomeTime(); + curStatus = client.updateModel(); + } + if (curStatus == FLClientStatus.RESTART) { + restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + continue; + } else if (curStatus == FLClientStatus.FAILED) { + failed("[updateModel] updateModel", client.getIteration(), client.getRetCode(), curStatus); + break; + } + LOGGER.info(Common.addTag("[updateModel] updateModel succeed")); + + // unmasking + curStatus = client.unMasking(); + if (curStatus == FLClientStatus.RESTART) { + restart("[Encrypt] unmasking", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + continue; + } else if (curStatus == FLClientStatus.FAILED) { + failed("[Encrypt] unmasking", client.getIteration(), client.getRetCode(), curStatus); + break; + } + + // getModel + curStatus = client.getModel(); + while (curStatus == FLClientStatus.WAIT) { + waitSomeTime(); + curStatus = client.getModel(); + } + if (curStatus == FLClientStatus.RESTART) { + restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + continue; + } else if (curStatus == FLClientStatus.FAILED) { + failed("[getModel] getModel", client.getIteration(), client.getRetCode(), curStatus); + break; + } + LOGGER.info(Common.addTag("[getModel] getModel succeed")); + + //evaluate model after getting model from server + if (flParameter.getTestDataset().equals("null")) { + LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model")); + } else { + curStatus = client.evaluateModel(); + if (curStatus == FLClientStatus.FAILED) { + failed("[evaluate] evaluate", client.getIteration(), client.getRetCode(), curStatus); + break; + } + LOGGER.info(Common.addTag("[evaluate] evaluate succeed")); + } + LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================")); + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode()); + } while (client.getIteration() < client.getIterations()); + client.finalize(); + LOGGER.info(Common.addTag("flJobRun finish")); + flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode()); + } + + public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) { + int[] labels = new int[0]; + if (flName.equals(ADBERT)) { + AdInferBert adInferBert = AdInferBert.getInstance(); + LOGGER.info(Common.addTag("===========model inference=============")); + labels = adInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile); + LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); + SessionUtil.free(adInferBert.getTrainSession()); + LOGGER.info(Common.addTag("[model inference] inference finish")); + } else if (flName.equals(LENET)) { + TrainLenet trainLenet = TrainLenet.getInstance(); + LOGGER.info(Common.addTag("===========model inference=============")); + labels = trainLenet.inferModel(modelPath, dataPath.split(",")[0]); + LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); + SessionUtil.free(trainLenet.getTrainSession()); + LOGGER.info(Common.addTag("[model inference] inference finish")); + } + if (labels.length == 0) { + LOGGER.severe(Common.addTag("[model inference] the return labels is null.")); + } + return labels; + } + + public FLClientStatus getModel(boolean useElb, int serverNum, String ip, int port, String flName, String trainModelPath, String inferModelPath, boolean useSSL) { + int tag = 0; + flParameter.setTrainModelPath(trainModelPath); + flParameter.setInferModelPath(inferModelPath); + FLClientStatus status = FLClientStatus.SUCCESS; + try { + if (flName.equals(ADBERT)) { + localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString()); + LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session=============")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + tag = adTrainBert.initSessionAndInputs(trainModelPath, true); + if (tag == -1) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + inferModelPath + " Create inference Session=============")); + AdInferBert adInferBert = AdInferBert.getInstance(); + tag = adInferBert.initSessionAndInputs(inferModelPath, false); + } else if (flName.equals(LENET)) { + localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString()); + LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session=============")); + TrainLenet trainLenet = TrainLenet.getInstance(); + tag = trainLenet.initSessionAndInputs(trainModelPath, true); + } + if (tag == -1) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); + return FLClientStatus.FAILED; + } + flParameter.setUseSSL(useSSL); + FLCommunication flCommunication = FLCommunication.getInstance(); + String url = Common.generateUrl(useElb, ip, port, serverNum); + LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "==============")); + GetModel getModelBuf = GetModel.getInstance(); + byte[] buffer = getModelBuf.getRequestGetModel(flName, 0); + byte[] message = flCommunication.syncRequest(url + "/getModel", buffer); + LOGGER.info(Common.addTag("[getModel] get model request success")); + ByteBuffer debugBuffer = ByteBuffer.wrap(message); + ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer); + status = getModelBuf.doResponse(responseDataBuf); + LOGGER.info(Common.addTag("[getModel] success!")); + } catch (Exception e) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage())); + status = FLClientStatus.FAILED; + } + if (flName.equals(ADBERT)) { + LOGGER.info(Common.addTag("===========free train session=============")); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + SessionUtil.free(adTrainBert.getTrainSession()); + LOGGER.info(Common.addTag("===========free inference session=============")); + AdInferBert adInferBert = AdInferBert.getInstance(); + SessionUtil.free(adInferBert.getTrainSession()); + } else if (flName.equals(LENET)) { + LOGGER.info(Common.addTag("===========free session=============")); + TrainLenet trainLenet = TrainLenet.getInstance(); + SessionUtil.free(trainLenet.getTrainSession()); + } + return status; + } + + private void waitSomeTime() { + if (flParameter.getSleepTime() != 0) + Common.sleep(flParameter.getSleepTime()); + else + Common.sleep(SLEEP_TIME); + } + + private void waitNextReqTime(String nextReqTime) { + if (flParameter.isTimer()) { + long waitTime = Common.getWaitTime(nextReqTime); + Common.sleep(waitTime); + } else { + waitSomeTime(); + } + } + + private void restart(String tag, String nextReqTime, int iteration, int retcode) { + LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again")); + waitNextReqTime(nextReqTime); + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode); + } + + private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) { + LOGGER.info(Common.addTag(tag + " failed")); + LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================")); + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode); + } + + public static void main(String[] args) { + String trainDataset = args[0]; + String vocabFile = args[1]; + String idsFile = args[2]; + String testDataset = args[3]; + String flName = args[4]; + String trainModelPath = args[5]; + String inferModelPath = args[6]; + String clientID = args[7]; + String ip = args[8]; + boolean useSSL = Boolean.parseBoolean(args[9]); + int port = Integer.parseInt(args[10]); + int timeWindow = Integer.parseInt(args[11]); + boolean useElb = Boolean.parseBoolean(args[12]); + int serverNum = Integer.parseInt(args[13]); + String task = args[14]; + FLParameter flParameter = FLParameter.getInstance(); + LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset)); + LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile)); + LOGGER.info(Common.addTag("[args] idsFile: " + idsFile)); + LOGGER.info(Common.addTag("[args] testDataset: " + testDataset)); + LOGGER.info(Common.addTag("[args] flName: " + flName)); + LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath)); + LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath)); + LOGGER.info(Common.addTag("[args] clientID: " + clientID)); + LOGGER.info(Common.addTag("[args] ip: " + ip)); + LOGGER.info(Common.addTag("[args] useSSL: " + useSSL)); + LOGGER.info(Common.addTag("[args] port: " + port)); + LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow)); + LOGGER.info(Common.addTag("[args] useElb: " + useElb)); + LOGGER.info(Common.addTag("[args] serverNum: " + serverNum)); + LOGGER.info(Common.addTag("[args] task: " + task)); + + flParameter.setClientID(clientID); + SyncFLJob syncFLJob = new SyncFLJob(); + if (task.equals("train")) { + flParameter.setTrainDataset(trainDataset); + flParameter.setFlName(flName); + flParameter.setTrainModelPath(trainModelPath); + flParameter.setTestDataset(testDataset); + flParameter.setInferModelPath(inferModelPath); + flParameter.setIp(ip); + flParameter.setUseSSL(useSSL); + flParameter.setPort(port); + flParameter.setTimeWindow(timeWindow); + flParameter.setUseElb(useElb); + flParameter.setServerNum(serverNum); + if (ADBERT.equals(flName)) { + flParameter.setVocabFile(vocabFile); + flParameter.setIdsFile(idsFile); + } + syncFLJob.flJobRun(); + } else if (task.equals("inference")) { + syncFLJob.modelInference(flName, testDataset, vocabFile, idsFile, inferModelPath); + } else if (task.equals("getModel")) { + syncFLJob.getModel(false, 1, ip, port, flName, trainModelPath, inferModelPath, false); + } else { + LOGGER.info(Common.addTag("do not do any thing!")); + } + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java new file mode 100644 index 00000000000..58e99c444a0 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java @@ -0,0 +1,210 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient; + +import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.TrainLenet; +import mindspore.schema.FeatureMap; +import mindspore.schema.RequestUpdateModel; +import mindspore.schema.ResponseCode; +import mindspore.schema.ResponseUpdateModel; + +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Logger; + +import static com.mindspore.flclient.LocalFLParameter.ADBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + +public class UpdateModel { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + class RequestUpdateModelBuilder { + private RequestUpdateModel requestUM; + private FlatBufferBuilder builder; + private int fmOffset = 0; + private int nameOffset = 0; + private int idOffset = 0; + private int timestampOffset = 0; + private int iteration = 0; + private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT; + + public RequestUpdateModelBuilder(EncryptLevel encryptLevel) { + builder = new FlatBufferBuilder(); + this.encryptLevel = encryptLevel; + } + + public RequestUpdateModelBuilder flName(String name) { + this.nameOffset = this.builder.createString(name); + return this; + } + + public RequestUpdateModelBuilder time() { + Date date = new Date(); + long time = date.getTime(); + this.timestampOffset = builder.createString(String.valueOf(time)); + return this; + } + + public RequestUpdateModelBuilder iteration(int iteration) { + this.iteration = iteration; + return this; + } + + public RequestUpdateModelBuilder id(String id) { + this.idOffset = this.builder.createString(id); + return this; + } + + public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) { + ArrayList encryptFeatureName = secureProtocol.getEncryptFeatureName(); + switch (encryptLevel) { + case PW_ENCRYPT: + try { + int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize); + this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); + LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!")); + return this; + } catch (Exception e) { + LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception"); + } + case DP_ENCRYPT: + try { + int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize); + this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); + LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!")); + return this; + } catch (Exception e) { + LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage())); + } + case NOT_ENCRYPT: + default: + Map map = new HashMap(); + if (flParameter.getFlName().equals(ADBERT)) { + LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName())); + AdTrainBert adTrainBert = AdTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + if (map.isEmpty()) { + LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); + status = FLClientStatus.FAILED; + } + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName())); + TrainLenet trainLenet = TrainLenet.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); + if (map.isEmpty()) { + LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); + status = FLClientStatus.FAILED; + } + } + int featureSize = encryptFeatureName.size(); + int[] fmOffsets = new int[featureSize]; + for (int i = 0; i < featureSize; i++) { + String key = encryptFeatureName.get(i); + float[] data = map.get(key); + LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length)); + for (int j = 0; j < data.length; j++) { + float rawData = data[j]; + data[j] = data[j] * trainDataSize; + } + int featureName = builder.createString(key); + int weight = FeatureMap.createDataVector(builder, data); + int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); + fmOffsets[i] = featureMap; + } + this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets); + return this; + } + } + + public byte[] build() { + RequestUpdateModel.startRequestUpdateModel(this.builder); + RequestUpdateModel.addFlName(builder, nameOffset); + RequestUpdateModel.addFlId(this.builder, idOffset); + RequestUpdateModel.addTimestamp(builder, this.timestampOffset); + RequestUpdateModel.addIteration(builder, this.iteration); + RequestUpdateModel.addFeatureMap(builder, this.fmOffset); + int root = RequestUpdateModel.endRequestUpdateModel(builder); + builder.finish(root); + return builder.sizedByteArray(); + } + } + + private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString()); + private FLParameter flParameter = FLParameter.getInstance(); + private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private String nextRequestTime; + private FLClientStatus status; + private static volatile UpdateModel updateModel; + + private UpdateModel() { + } + + public static synchronized UpdateModel getInstance() { + UpdateModel localRef = updateModel; + if (localRef == null) { + synchronized (UpdateModel.class) { + localRef = updateModel; + if (localRef == null) { + updateModel = localRef = new UpdateModel(); + } + } + } + return localRef; + } + + public String getNextRequestTime() { + return nextRequestTime; + } + + public FLClientStatus getStatus() { + return status; + } + + public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) { + RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel()); + return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build(); + } + + public FLClientStatus doResponse(ResponseUpdateModel response) { + LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================")); + LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode())); + LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason())); + LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime())); + nextRequestTime = response.nextReqTime(); + switch (response.retcode()) { + case (ResponseCode.SUCCEED): + LOGGER.info(Common.addTag("[updateModel] updateModel success")); + return FLClientStatus.SUCCESS; + case (ResponseCode.OutOfTime): + return FLClientStatus.RESTART; + case (ResponseCode.RequestError): + case (ResponseCode.SystemError): + LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError")); + return FLClientStatus.FAILED; + default: + LOGGER.severe(Common.addTag("[updateModel]the return from server is invalid: " + response.retcode())); + return FLClientStatus.FAILED; + } + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java index c9255e5caab..55e27265306 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java @@ -45,6 +45,7 @@ public class ClientListReq { private String nextRequestTime; private FLParameter flParameter = FLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private int retCode; public ClientListReq() { flCommunication = FLCommunication.getInstance(); @@ -58,6 +59,10 @@ public class ClientListReq { this.nextRequestTime = nextRequestTime; } + public int getRetCode() { + return retCode; + } + public FLClientStatus getClientList(int iteration, List u3ClientList, List decryptSecretsList, List returnShareList, Map cuvKeys) { String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "==============")); @@ -81,15 +86,15 @@ public class ClientListReq { } public FLClientStatus judgeGetClientList(ReturnClientList bufData, List u3ClientList, List decryptSecretsList, List returnShareList, Map cuvKeys) { - int retcode = bufData.retcode(); + retCode = bufData.retcode(); LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************")); - LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); LOGGER.info(Common.addTag("[PairWiseMask] the size of clients: " + bufData.clientsLength())); FLClientStatus status; - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[PairWiseMask] GetClientList success")); u3ClientList.clear(); @@ -117,7 +122,7 @@ public class ClientListReq { LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnClientList is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnClientList is invalid: " + retCode)); return FLClientStatus.FAILED; } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java index e12ab36f465..044be955bc1 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java @@ -37,6 +37,7 @@ public class ReconstructSecretReq { private String nextRequestTime; private FLParameter flParameter = FLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private int retCode; public String getNextRequestTime() { return nextRequestTime; @@ -46,6 +47,10 @@ public class ReconstructSecretReq { this.nextRequestTime = nextRequestTime; } + public int getRetCode() { + return retCode; + } + public ReconstructSecretReq() { flCommunication = FLCommunication.getInstance(); } @@ -99,13 +104,13 @@ public class ReconstructSecretReq { } public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) { - int retcode = bufData.retcode(); + retCode = bufData.retcode(); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************")); - LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success")); return FLClientStatus.SUCCESS; @@ -118,7 +123,7 @@ public class ReconstructSecretReq { LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReconstructSecret is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReconstructSecret is invalid: " + retCode)); return FLClientStatus.FAILED; } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java index 8dfc5f3ba53..2851d8a49e9 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java @@ -26,7 +26,7 @@ public class AdTrainBert extends AdBert { private static volatile AdTrainBert adTrainBert; public static AdTrainBert getInstance() { - AdTrainBert localRef = adInferBert; + AdTrainBert localRef = adTrainBert; if (localRef == null) { synchronized (AdTrainBert.class) { localRef = adTrainBert;