!18797 add code of interactive rounds of startFLJob/updateModel/getModel for fl_client

Merge pull request !18797 from zhoushan33/flclient0624_om
This commit is contained in:
i-robot 2021-06-26 03:27:08 +00:00 committed by Gitee
commit 6e6fb37498
14 changed files with 1984 additions and 35 deletions

View File

@ -83,6 +83,7 @@ public class CipherClient {
private Random random = new Random();
private ClientListReq clientListReq = new ClientListReq();
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
private int retCode;
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
flCommunication = FLCommunication.getInstance();
@ -110,6 +111,10 @@ public class CipherClient {
return nextRequestTime;
}
public int getRetCode() {
return retCode;
}
public void genDHKeyPairs() {
byte[] csk = keyAgreement.generatePrivateKey();
byte[] cpk = keyAgreement.generatePublicKey(csk);
@ -282,13 +287,13 @@ public class CipherClient {
}
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
int retcode = bufData.retcode();
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
return FLClientStatus.SUCCESS;
@ -301,7 +306,7 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseExchangeKeys is invalid: " + retcode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
@ -329,12 +334,12 @@ public class CipherClient {
}
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
int retcode = bufData.retcode();
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys success"));
clientPublicKeyList.clear();
@ -366,7 +371,7 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnExchangeKeys is invalid: " + retcode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
@ -417,13 +422,13 @@ public class CipherClient {
}
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
int retcode = bufData.retcode();
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
return FLClientStatus.SUCCESS;
@ -436,7 +441,7 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseShareSecrets is invalid: " + retcode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
@ -464,13 +469,13 @@ public class CipherClient {
}
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
int retcode = bufData.retcode();
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
LOGGER.info(Common.addTag("[PairWiseMask] the size of encrypted shares: " + bufData.encryptedSharesLength()));
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets success"));
returnShareList.clear();
@ -499,7 +504,7 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnShareSecrets is invalid: " + retcode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
@ -563,6 +568,7 @@ public class CipherClient {
if (curStatus != FLClientStatus.SUCCESS) {
return curStatus;
}
retCode = clientListReq.getRetCode();
// SendReconstructSecret
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
@ -573,6 +579,7 @@ public class CipherClient {
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = reconstructSecretReq.getNextRequestTime();
}
retCode = reconstructSecretReq.getRetCode();
return curStatus;
}
}

View File

@ -88,7 +88,7 @@ public class Common {
Date date = new Date();
long currentTime = date.getTime();
long waitTime = 0;
if (!nextRequestTime.equals("")) {
if (!("").equals(nextRequestTime)) {
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
}
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
@ -114,12 +114,20 @@ public class Common {
return LOG_TITLE + message;
}
public static boolean isAutoscaling(byte[] message, String autoscalingTag) {
return (new String(message)).contains(autoscalingTag);
public static boolean isSafeMod(byte[] message, String safeModTag) {
return (new String(message)).contains(safeModTag);
}
public static boolean checkPath(String path) {
File file = new File(path);
return file.exists();
boolean tag = true;
String [] paths = path.split(",");
for (int i = 0; i < paths.length; i++) {
LOGGER.info(addTag("[check path]:" + paths[i]));
File file = new File(paths[i]);
if (!file.exists()) {
tag = false;
}
}
return tag;
}
}

View File

@ -48,7 +48,7 @@ public class FLCommunication implements IFLCommunication {
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
private OkHttpClient client;
private static FLCommunication communication;
private static volatile FLCommunication communication;
private FLCommunication() {
if (flParameter.getTimeOut() != 0) {
@ -109,14 +109,16 @@ public class FLCommunication implements IFLCommunication {
}
public static FLCommunication getInstance() {
if (communication == null) {
FLCommunication localRef = communication;
if (localRef == null) {
synchronized (FLCommunication.class) {
if (communication == null) {
communication = new FLCommunication();
localRef = communication;
if (localRef == null) {
communication = localRef = new FLCommunication();
}
}
}
return communication;
return localRef;
}
@Override

View File

@ -0,0 +1,535 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.mindspore.flclient.cipher.BaseUtil;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.CipherPublicParams;
import mindspore.schema.FLPlan;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseFLJob;
import mindspore.schema.ResponseGetModel;
import mindspore.schema.ResponseUpdateModel;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class FLLiteClient {
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
private FLCommunication flCommunication;
private FLClientStatus status;
private int retCode;
private static int iteration = 0;
private int iterations = 1;
private int epochs = 1;
private int batchSize = 16;
private int minSecretNum;
private byte[] prime;
private int featureSize;
private int trainDataSize;
private double dpEps = 100;
private double dpDelta = 0.01;
private double dpNormClip = 2.0;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private SecureProtocol secureProtocol = new SecureProtocol();
private static Map<String, float[]> mapBeforeTrain;
private String nextRequestTime;
public FLLiteClient() {
flCommunication = FLCommunication.getInstance();
}
public int setGlobalParameters(ResponseFLJob flJob) {
FLPlan flPlan = flJob.flPlanConfig();
if (flPlan == null) {
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
return -1;
}
iterations = flPlan.iterations();
epochs = flPlan.epochs();
batchSize = flPlan.miniBatch();
String serverMod = flPlan.serverMode();
localFLParameter.setServerMod(serverMod);
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <serverMod> from server: " + serverMod));
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AdTrainBert: " + batchSize));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
adTrainBert.setBatchSize(batchSize);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
TrainLenet trainLenet = TrainLenet.getInstance();
trainLenet.setBatchSize(batchSize);
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
CipherPublicParams cipherPublicParams = flPlan.cipher();
String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
localFLParameter.setEncryptLevel(encryptLevel);
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server: " + encryptLevel));
switch (localFLParameter.getEncryptLevel()) {
case PW_ENCRYPT:
minSecretNum = cipherPublicParams.t();
int primeLength = cipherPublicParams.primeLength();
prime = new byte[primeLength];
for (int i = 0; i < primeLength; i++) {
prime[i] = (byte) cipherPublicParams.prime(i);
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps();
dpDelta = cipherPublicParams.dpDelta();
dpNormClip = cipherPublicParams.dpNormClip();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClip> from server: " + dpNormClip));
break;
default:
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
}
return 0;
}
public int getRetCode() {
return retCode;
}
public int getIteration() {
return iteration;
}
public int getIterations() {
return iterations;
}
public int getEpochs() {
return epochs;
}
public int getBatchSize() {
return batchSize;
}
public String getNextRequestTime() {
return nextRequestTime;
}
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public void setTrainDataSize(int trainDataSize) {
this.trainDataSize = trainDataSize;
}
public FLClientStatus checkStatus() {
return this.status;
}
public FLClientStatus startFLJob() {
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server===================================="));
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[startFLJob] ==============startFLJob url: " + url + "=============="));
StartFLJob startFLJob = StartFLJob.getInstance();
Date date = new Date();
long time = date.getTime();
byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time);
try {
long start = Common.startTime("single startFLJob");
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again"));
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return status;
}
LOGGER.info(Common.addTag("[startFLJob] the response message length: " + message.length));
Common.endTime(start, "single startFLJob");
ByteBuffer buffer = ByteBuffer.wrap(message);
ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer);
status = judgeStartFLJob(startFLJob, responseDataBuf);
retCode = responseDataBuf.retcode();
} catch (IOException e) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
return status;
}
public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
iteration = responseDataBuf.iteration();
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
status = response;
switch (response) {
case SUCCESS:
LOGGER.info(Common.addTag("[startFLJob] startFLJob success"));
featureSize = startFLJob.getFeatureSize();
secureProtocol.setEncryptFeatureName(startFLJob.getEncryptFeatureName());
LOGGER.info(Common.addTag("[startFLJob] ***the feature size get in ResponseFLJob***: " + featureSize));
int tag = setGlobalParameters(responseDataBuf);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] setGlobalParameters failed"));
status = FLClientStatus.FAILED;
}
break;
case RESTART:
FLPlan flPlan = responseDataBuf.flPlanConfig();
iterations = flPlan.iterations();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
nextRequestTime = responseDataBuf.nextReqTime();
break;
case FAILED:
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
break;
default:
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>"));
status = FLClientStatus.FAILED;
}
return status;
}
public FLClientStatus localTrain() {
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ADBERT)) {
LOGGER.info(Common.addTag("[train] train in adbert"));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
int tag = adTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
LOGGER.severe(Common.addTag("[train] unsolved error code in <adTrainBert.trainModel>"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[train] train in lenet"));
TrainLenet trainLenet = TrainLenet.getInstance();
int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
LOGGER.severe(Common.addTag("[train] unsolved error code in <trainLenet.trainModel>"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
}
return status;
}
public FLClientStatus updateModel() {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[updateModel] ==============updateModel url: " + url + "=============="));
UpdateModel updateModelBuf = UpdateModel.getInstance();
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
LOGGER.info(Common.addTag("[updateModel] catch error in build RequestUpdateFLJob"));
return FLClientStatus.FAILED;
}
try {
long start = Common.startTime("single updateModel");
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again"));
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return status;
}
LOGGER.info(Common.addTag("[updateModel] the response message length: " + message.length));
Common.endTime(start, "single updateModel");
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer);
status = updateModelBuf.doResponse(responseDataBuf);
retCode = responseDataBuf.retcode();
if (status == FLClientStatus.RESTART) {
nextRequestTime = responseDataBuf.nextReqTime();
}
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
} catch (IOException e) {
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
return status;
}
public FLClientStatus getModel() {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
try {
long start = Common.startTime("single getModel");
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
status = FLClientStatus.WAIT;
return status;
}
LOGGER.info(Common.addTag("[getModel] the response message length: " + message.length));
Common.endTime(start, "single getModel");
LOGGER.info(Common.addTag("[getModel] get model request success"));
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
status = getModelBuf.doResponse(responseDataBuf);
retCode = responseDataBuf.retcode();
if (status == FLClientStatus.RESTART) {
nextRequestTime = responseDataBuf.timestamp();
}
LOGGER.info(Common.addTag("[getModel] get response from server ok!"));
} catch (IOException e) {
LOGGER.severe(Common.addTag("[getModel] un sloved error code: catch IOException: " + e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
return status;
}
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
if (mapBeforeTrain == null) {
Map<String, float[]> copyMap = new TreeMap<>();
for (String key : map.keySet()) {
float[] data = map.get(key);
int dataLen = data.length;
float[] weights = new float[dataLen];
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
for (int j = 0; j < dataLen; j++) {
float weight = data[j];
weights[j] = weight;
}
copyMap.put(key, weights);
}
}
mapBeforeTrain = copyMap;
} else {
for (String key : map.keySet()) {
float[] data = map.get(key);
float[] copyData = mapBeforeTrain.get(key);
int dataLen = data.length;
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
for (int j = 0; j < dataLen; j++) {
copyData[j] = data[j];
}
}
}
}
return mapBeforeTrain;
}
public FLClientStatus getFeatureMask() {
FLClientStatus curStatus;
switch (localFLParameter.getEncryptLevel()) {
case PW_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">"));
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
curStatus = secureProtocol.pwCreateMask();
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = secureProtocol.getNextRequestTime();
}
retCode = secureProtocol.getRetCode();
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
return curStatus;
case DP_ENCRYPT:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ADBERT)) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
Map<String, float[]> copyMap = getOldMapCopy(map);
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClip, copyMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!"));
return FLClientStatus.SUCCESS;
case NOT_ENCRYPT:
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("[Encrypt] don't mask model"));
return FLClientStatus.SUCCESS;
default:
retCode = ResponseCode.SUCCEED;
LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default"));
return FLClientStatus.SUCCESS;
}
}
public FLClientStatus unMasking() {
FLClientStatus curStatus;
switch (localFLParameter.getEncryptLevel()) {
case PW_ENCRYPT:
curStatus = secureProtocol.pwUnmasking();
retCode = secureProtocol.getRetCode();
LOGGER.info(Common.addTag("[Encrypt] the response of unmasking : " + curStatus));
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = secureProtocol.getNextRequestTime();
}
return curStatus;
case DP_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking"));
retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS;
case NOT_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] haven't mask model"));
retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS;
default:
LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default"));
retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS;
}
}
public FLClientStatus evaluateModel() {
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("===================================test combine model from server==================================="));
if (flParameter.getFlName().equals(ADBERT)) {
AdInferBert adInferBert = AdInferBert.getInstance();
int dataSize = adInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.initDataSet>: the return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
float acc = adInferBert.evalModel();
if (acc == Float.NaN) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
float acc = trainLenet.evalModel();
if (acc == Float.NaN) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1]));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
}
return status;
}
/**
* @param dataPath, train or test dataset and label set
*/
public int setInput(String dataPath) {
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("==========set input==========="));
int dataSize = 0;
if (flParameter.getFlName().equals(ADBERT)) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
dataSize = adTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1]));
}
if (dataSize <= 0) {
retCode = ResponseCode.RequestError;
return -1;
}
return dataSize;
}
public FLClientStatus initSession() {
int tag = 0;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ADBERT)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
@Override
protected void finalize() {
if (flParameter.getFlName().equals(ADBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
SessionUtil.free(adTrainBert.getTrainSession());
if (!flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("===========free inference session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
SessionUtil.free(adInferBert.getTrainSession());
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
SessionUtil.free(trainLenet.getTrainSession());
}
}
}

View File

@ -0,0 +1,307 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.util.logging.Logger;
public class FLParameter {
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
public static final int TIME_OUT = 100;
public static final int SLEEP_TIME = 1000;
private String hostName;
private String certPath;
private String trainDataset;
private String vocabFile = "null";
private String idsFile = "null";
private String testDataset = "null";
private String flName;
private String trainModelPath;
private String inferModelPath;
private String clientID;
private String ip;
private int port;
private boolean useSSL = false;
private int timeOut;
private int sleepTime;
private boolean useElb = false;
private int serverNum = 1;
private boolean timer = true;
private int timeWindow = 6000;
private int reRequestNum = timeWindow / SLEEP_TIME + 1;
private static volatile FLParameter flParameter;
public static FLParameter getInstance() {
FLParameter localRef = flParameter;
if (localRef == null) {
synchronized (FLParameter.class) {
localRef = flParameter;
if (localRef == null) {
flParameter = localRef = new FLParameter();
}
}
}
return localRef;
}
public String getHostName() {
if ("".equals(hostName) || hostName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <hostName> is null, please set it before use"));
throw new RuntimeException();
}
return hostName;
}
public void setHostName(String hostName) {
this.hostName = hostName;
}
public String getCertPath() {
if ("".equals(certPath) || certPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null, please set it before use"));
throw new RuntimeException();
}
return certPath;
}
public void setCertPath(String certPath) {
this.certPath = certPath;
}
public String getTrainDataset() {
if ("".equals(trainDataset) || trainDataset.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null, please set it before use"));
throw new RuntimeException();
}
return trainDataset;
}
public void setTrainDataset(String trainDataset) {
if (Common.checkPath(trainDataset)) {
this.trainDataset = trainDataset;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getVocabFile() {
if ("null".equals(vocabFile) && "adbert".equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
throw new RuntimeException();
}
return vocabFile;
}
public void setVocabFile(String vocabFile) {
if (Common.checkPath(vocabFile)) {
this.vocabFile = vocabFile;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getIdsFile() {
if ("null".equals(idsFile) && "adbert".equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
throw new RuntimeException();
}
return idsFile;
}
public void setIdsFile(String idsFile) {
if (Common.checkPath(idsFile)) {
this.idsFile = idsFile;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getTestDataset() {
return testDataset;
}
public void setTestDataset(String testDataset) {
if (Common.checkPath(testDataset)) {
this.testDataset = testDataset;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getFlName() {
if ("".equals(flName) || flName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null, please set it before use"));
throw new RuntimeException();
}
return flName;
}
public void setFlName(String flName) {
if (Common.checkFLName(flName)) {
this.flName = flName;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in flNameTrustList, please check it before set"));
throw new RuntimeException();
}
}
public String getTrainModelPath() {
if ("".equals(trainModelPath) || trainModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null, please set it before use"));
throw new RuntimeException();
}
return trainModelPath;
}
public void setTrainModelPath(String trainModelPath) {
if (Common.checkPath(trainModelPath)) {
this.trainModelPath = trainModelPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getInferModelPath() {
if ("".equals(inferModelPath) || inferModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null, please set it before use"));
throw new RuntimeException();
}
return inferModelPath;
}
public void setInferModelPath(String inferModelPath) {
if (Common.checkPath(inferModelPath)) {
this.inferModelPath = inferModelPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getIp() {
if ("".equals(ip) || ip.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is null, please set it before use"));
throw new RuntimeException();
}
return ip;
}
public void setIp(String ip) {
this.ip = ip;
}
public boolean isUseSSL() {
return useSSL;
}
public void setUseSSL(boolean useSSL) {
this.useSSL = useSSL;
}
public int getPort() {
if (port == 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is null, please set it before use"));
throw new RuntimeException();
}
return port;
}
public void setPort(int port) {
this.port = port;
}
public int getTimeOut() {
return timeOut;
}
public void setTimeOut(int timeOut) {
this.timeOut = timeOut;
}
public int getSleepTime() {
return sleepTime;
}
public void setSleepTime(int sleepTime) {
this.sleepTime = sleepTime;
}
public boolean isUseElb() {
return useElb;
}
public void setUseElb(boolean useElb) {
this.useElb = useElb;
}
public int getServerNum() {
if (serverNum == 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is zero, please set it before use"));
throw new RuntimeException();
}
return serverNum;
}
public void setServerNum(int serverNum) {
this.serverNum = serverNum;
}
public boolean isTimer() {
return timer;
}
public void setTimer(boolean timer) {
this.timer = timer;
}
public int getTimeWindow() {
return timeWindow;
}
public void setTimeWindow(int timeWindow) {
this.timeWindow = timeWindow;
}
public int getReRequestNum() {
return reRequestNum;
}
public void setReRequestNum(int reRequestNum) {
this.reRequestNum = reRequestNum;
}
public String getClientID() {
if ("".equals(clientID) || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null, please set it before use"));
throw new RuntimeException();
}
return clientID;
}
public void setClientID(String clientID) {
this.clientID = clientID;
}
}

View File

@ -0,0 +1,184 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestGetModel;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseGetModel;
import java.util.ArrayList;
import java.util.Date;
import java.util.logging.Logger;
public class GetModel {
static {
System.loadLibrary("mindspore-lite-jni");
}
class RequestGetModelBuilder {
private FlatBufferBuilder builder;
private int nameOffset = 0;
private int iteration = 0;
private int timeStampOffset = 0;
public RequestGetModelBuilder() {
builder = new FlatBufferBuilder();
}
public RequestGetModelBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestGetModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timeStampOffset = builder.createString(String.valueOf(time));
return this;
}
public RequestGetModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public byte[] build() {
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
private static GetModel getModel;
private GetModel() {
}
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
public static GetModel getInstance() {
if (getModel == null) {
getModel = new GetModel();
}
return getModel;
}
public byte[] getRequestGetModel(String name, int iteration) {
RequestGetModelBuilder builder = new RequestGetModelBuilder();
return builder.iteration(iteration).flName(name).time().build();
}
private FLClientStatus parseResponseAdbert(ResponseGetModel responseDataBuf) {
int fmCount = responseDataBuf.featureMapLength();
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
}
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus parseResponseLenet(ResponseGetModel responseDataBuf) {
int fmCount = responseDataBuf.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
String featureName = feature.weightFullname();
featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
LOGGER.info(Common.addTag("[getModel] ==========reason: " + responseDataBuf.reason()));
LOGGER.info(Common.addTag("[getModel] ==========iteration: " + responseDataBuf.iteration()));
LOGGER.info(Common.addTag("[getModel] ==========time: " + responseDataBuf.timestamp()));
FLClientStatus status = FLClientStatus.SUCCESS;
int retCode = responseDataBuf.retcode();
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[getModel] getModel response success"));
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
status = parseResponseAdbert(responseDataBuf);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
status = parseResponseLenet(responseDataBuf);
}
return status;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[getModel] server is not ready now: need wait and request getModel again"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[getModel] out of time: need wait and request startFLJob again"));
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.warning(Common.addTag("[getModel] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[getModel] the return <retCode> from server is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
}

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
public class LocalFLParameter {
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
public static final int SEED_SIZE = 32;
public static final int IVEC_LEN = 16;
public static final String LENET = "lenet";
public static final String ADBERT = "adbert";
private List<String> classifierWeightName = new ArrayList<>();
private List<String> albertWeightName = new ArrayList<>();
private String flID;
private String encryptLevel = "NotEncrypt";
private String earlyStopMod = "NotEarlyStop";
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
private String safeMod = "The cluster is in safemode.";
private static volatile LocalFLParameter localFLParameter;
private LocalFLParameter() {
// set classifierWeightName albertWeightName
Common.setClassifierWeightName(classifierWeightName);
Common.setAlbertWeightName(albertWeightName);
}
public static synchronized LocalFLParameter getInstance() {
LocalFLParameter localRef = localFLParameter;
if (localRef == null) {
synchronized (LocalFLParameter.class) {
localRef = localFLParameter;
if (localRef == null) {
localFLParameter = localRef = new LocalFLParameter();
}
}
}
return localRef;
}
public List<String> getClassifierWeightName() {
if (classifierWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
throw new RuntimeException();
}
return classifierWeightName;
}
public void setClassifierWeightName(List<String> classifierWeightName) {
this.classifierWeightName = classifierWeightName;
}
public List<String> getAlbertWeightName() {
if (albertWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
throw new RuntimeException();
}
return albertWeightName;
}
public void setAlbertWeightName(List<String> albertWeightName) {
this.albertWeightName = albertWeightName;
}
public String getFlID() {
if ("".equals(flID) || flID == null) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before use"));
throw new RuntimeException();
}
return flID;
}
public void setFlID(String flID) {
this.flID = flID;
}
public EncryptLevel getEncryptLevel() {
return EncryptLevel.valueOf(encryptLevel);
}
public void setEncryptLevel(String encryptLevel) {
this.encryptLevel = encryptLevel;
}
public EarlyStopMod getEarlyStopMod() {
return EarlyStopMod.valueOf(earlyStopMod);
}
public void setEarlyStopMod(String earlyStopMod) {
this.earlyStopMod = earlyStopMod;
}
public String getServerMod() {
return serverMod;
}
public void setServerMod(String serverMod) {
this.serverMod = serverMod;
}
public String getSafeMod() {
return safeMod;
}
public void setSafeMod(String safeMod) {
this.safeMod = safeMod;
}
}

View File

@ -42,6 +42,7 @@ public class SecureProtocol {
private static double deltaError = 1e-6;
private static Map<String, float[]> modelMap;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private int retCode;
public FLClientStatus getStatus() {
return status;
@ -51,6 +52,10 @@ public class SecureProtocol {
return featureMask;
}
public int getRetCode() {
return retCode;
}
public SecureProtocol() {
}
@ -91,6 +96,7 @@ public class SecureProtocol {
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
// round 0
status = cipher.exchangeKeys();
retCode = cipher.getRetCode();
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
if (status != FLClientStatus.SUCCESS) {
return status;
@ -98,6 +104,7 @@ public class SecureProtocol {
// round 1
try {
status = cipher.shareSecrets();
retCode = cipher.getRetCode();
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
@ -109,6 +116,7 @@ public class SecureProtocol {
// round2
try {
featureMask = cipher.doubleMaskingWeight();
retCode = cipher.getRetCode();
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
@ -151,6 +159,7 @@ public class SecureProtocol {
public FLClientStatus pwUnmasking() {
status = cipher.reconstructSecrets(); // round3
retCode = cipher.getRetCode();
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
return status;
}

View File

@ -0,0 +1,230 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestFLJob;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseFLJob;
import java.util.ArrayList;
import java.util.logging.Logger;
public class StartFLJob {
static {
System.loadLibrary("mindspore-lite-jni");
}
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
class RequestStartFLJobBuilder {
private RequestFLJob requestFLJob;
private FlatBufferBuilder builder;
private int nameOffset = 0;
private int iteration = 0;
private int dataSize = 0;
private int timestampOffset = 0;
private int idOffset = 0;
public RequestStartFLJobBuilder() {
builder = new FlatBufferBuilder();
}
public RequestStartFLJobBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestStartFLJobBuilder id(String id) {
this.idOffset = this.builder.createString(id);
return this;
}
public RequestStartFLJobBuilder time(long timestamp) {
this.timestampOffset = builder.createString(String.valueOf(timestamp));
return this;
}
public RequestStartFLJobBuilder dataSize(int dataSize) {
// temp code need confirm
this.dataSize = dataSize;
LOGGER.info(Common.addTag("[startFLJob] the train data size: " + dataSize));
return this;
}
public RequestStartFLJobBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public byte[] build() {
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
this.dataSize, this.timestampOffset);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static StartFLJob startFLJob;
private FLClientStatus status;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int featureSize;
private String nextRequestTime;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private StartFLJob() {
}
public static StartFLJob getInstance() {
if (startFLJob == null) {
startFLJob = new StartFLJob();
}
return startFLJob;
}
public String getNextRequestTime() {
return nextRequestTime;
}
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
return builder.flName(flParameter.getFlName())
.time(time)
.id(localFLParameter.getFlID())
.dataSize(dataSize)
.iteration(iteration)
.build();
}
public int getFeatureSize() {
return featureSize;
}
public ArrayList<String> getEncryptFeatureName() {
return encryptFeatureName;
}
private FLClientStatus parseResponseAdbert(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
encryptFeatureName.clear();
if (fmCount <= 0) {
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
return FLClientStatus.FAILED;
}
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(feature.weightFullname());
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
}
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus parseResponseLenet(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
encryptFeatureName.clear();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
public FLClientStatus doResponse(ResponseFLJob flJob) {
LOGGER.info(Common.addTag("[startFLJob] return code: " + flJob.retcode()));
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
nextRequestTime = flJob.nextReqTime();
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
int retcode = flJob.retcode();
switch (retcode) {
case (ResponseCode.SUCCEED):
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAdbert>"));
parseResponseAdbert(flJob);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
parseResponseLenet(flJob);
}
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus getStatus() {
return this.status;
}
}

View File

@ -0,0 +1,322 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class SyncFLJob {
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
public SyncFLJob() {
}
public void flJobRun() {
localFLParameter.setFlID(flParameter.getClientID());
FLLiteClient client = new FLLiteClient();
client.initSession();
FLClientStatus curStatus;
do {
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
int trainDataSize = client.setInput(flParameter.getTrainDataset());
LOGGER.info(Common.addTag("train path: " + flParameter.getTrainDataset()));
LOGGER.info(Common.addTag("train data size: " + trainDataSize));
if (trainDataSize <= 0) {
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
curStatus = FLClientStatus.FAILED;
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
break;
}
client.setTrainDataSize(trainDataSize);
// startFLJob
curStatus = client.startFLJob();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.startFLJob();
}
if (curStatus == FLClientStatus.RESTART) {
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[startFLJob]", client.getIteration(), client.getRetCode(), curStatus);
break;
}
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration()));
// create mask
curStatus = client.getFeatureMask();
if (curStatus == FLClientStatus.RESTART) {
restart("[Encrypt] creatMask", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[Encrypt] createMask", client.getIteration(), client.getRetCode(), curStatus);
break;
}
// train
curStatus = client.localTrain();
if (curStatus == FLClientStatus.FAILED) {
failed("[train] train", client.getIteration(), client.getRetCode(), curStatus);
break;
}
LOGGER.info(Common.addTag("[train] train succeed"));
// updateModel
curStatus = client.updateModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.updateModel();
}
if (curStatus == FLClientStatus.RESTART) {
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[updateModel] updateModel", client.getIteration(), client.getRetCode(), curStatus);
break;
}
LOGGER.info(Common.addTag("[updateModel] updateModel succeed"));
// unmasking
curStatus = client.unMasking();
if (curStatus == FLClientStatus.RESTART) {
restart("[Encrypt] unmasking", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[Encrypt] unmasking", client.getIteration(), client.getRetCode(), curStatus);
break;
}
// getModel
curStatus = client.getModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.getModel();
}
if (curStatus == FLClientStatus.RESTART) {
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[getModel] getModel", client.getIteration(), client.getRetCode(), curStatus);
break;
}
LOGGER.info(Common.addTag("[getModel] getModel succeed"));
//evaluate model after getting model from server
if (flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
} else {
curStatus = client.evaluateModel();
if (curStatus == FLClientStatus.FAILED) {
failed("[evaluate] evaluate", client.getIteration(), client.getRetCode(), curStatus);
break;
}
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
}
LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================"));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
} while (client.getIteration() < client.getIterations());
client.finalize();
LOGGER.info(Common.addTag("flJobRun finish"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
}
public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) {
int[] labels = new int[0];
if (flName.equals(ADBERT)) {
AdInferBert adInferBert = AdInferBert.getInstance();
LOGGER.info(Common.addTag("===========model inference============="));
labels = adInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile);
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(adInferBert.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish"));
} else if (flName.equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
LOGGER.info(Common.addTag("===========model inference============="));
labels = trainLenet.inferModel(modelPath, dataPath.split(",")[0]);
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(trainLenet.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish"));
}
if (labels.length == 0) {
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
}
return labels;
}
public FLClientStatus getModel(boolean useElb, int serverNum, String ip, int port, String flName, String trainModelPath, String inferModelPath, boolean useSSL) {
int tag = 0;
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
FLClientStatus status = FLClientStatus.SUCCESS;
try {
if (flName.equals(ADBERT)) {
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = adTrainBert.initSessionAndInputs(trainModelPath, true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + inferModelPath + " Create inference Session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = adInferBert.initSessionAndInputs(inferModelPath, false);
} else if (flName.equals(LENET)) {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(trainModelPath, true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
return FLClientStatus.FAILED;
}
flParameter.setUseSSL(useSSL);
FLCommunication flCommunication = FLCommunication.getInstance();
String url = Common.generateUrl(useElb, ip, port, serverNum);
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flName, 0);
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
LOGGER.info(Common.addTag("[getModel] get model request success"));
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
status = getModelBuf.doResponse(responseDataBuf);
LOGGER.info(Common.addTag("[getModel] success!"));
} catch (Exception e) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
status = FLClientStatus.FAILED;
}
if (flName.equals(ADBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
SessionUtil.free(adTrainBert.getTrainSession());
LOGGER.info(Common.addTag("===========free inference session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
SessionUtil.free(adInferBert.getTrainSession());
} else if (flName.equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
SessionUtil.free(trainLenet.getTrainSession());
}
return status;
}
private void waitSomeTime() {
if (flParameter.getSleepTime() != 0)
Common.sleep(flParameter.getSleepTime());
else
Common.sleep(SLEEP_TIME);
}
private void waitNextReqTime(String nextReqTime) {
if (flParameter.isTimer()) {
long waitTime = Common.getWaitTime(nextReqTime);
Common.sleep(waitTime);
} else {
waitSomeTime();
}
}
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again"));
waitNextReqTime(nextReqTime);
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
}
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
LOGGER.info(Common.addTag(tag + " failed"));
LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================"));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
}
public static void main(String[] args) {
String trainDataset = args[0];
String vocabFile = args[1];
String idsFile = args[2];
String testDataset = args[3];
String flName = args[4];
String trainModelPath = args[5];
String inferModelPath = args[6];
String clientID = args[7];
String ip = args[8];
boolean useSSL = Boolean.parseBoolean(args[9]);
int port = Integer.parseInt(args[10]);
int timeWindow = Integer.parseInt(args[11]);
boolean useElb = Boolean.parseBoolean(args[12]);
int serverNum = Integer.parseInt(args[13]);
String task = args[14];
FLParameter flParameter = FLParameter.getInstance();
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile));
LOGGER.info(Common.addTag("[args] idsFile: " + idsFile));
LOGGER.info(Common.addTag("[args] testDataset: " + testDataset));
LOGGER.info(Common.addTag("[args] flName: " + flName));
LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath));
LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath));
LOGGER.info(Common.addTag("[args] clientID: " + clientID));
LOGGER.info(Common.addTag("[args] ip: " + ip));
LOGGER.info(Common.addTag("[args] useSSL: " + useSSL));
LOGGER.info(Common.addTag("[args] port: " + port));
LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow));
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
LOGGER.info(Common.addTag("[args] task: " + task));
flParameter.setClientID(clientID);
SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) {
flParameter.setTrainDataset(trainDataset);
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath);
flParameter.setIp(ip);
flParameter.setUseSSL(useSSL);
flParameter.setPort(port);
flParameter.setTimeWindow(timeWindow);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
if (ADBERT.equals(flName)) {
flParameter.setVocabFile(vocabFile);
flParameter.setIdsFile(idsFile);
}
syncFLJob.flJobRun();
} else if (task.equals("inference")) {
syncFLJob.modelInference(flName, testDataset, vocabFile, idsFile, inferModelPath);
} else if (task.equals("getModel")) {
syncFLJob.getModel(false, 1, ip, port, flName, trainModelPath, inferModelPath, false);
} else {
LOGGER.info(Common.addTag("do not do any thing!"));
}
}
}

View File

@ -0,0 +1,210 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestUpdateModel;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseUpdateModel;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class UpdateModel {
static {
System.loadLibrary("mindspore-lite-jni");
}
class RequestUpdateModelBuilder {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
private int fmOffset = 0;
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
private int iteration = 0;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
public RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
builder = new FlatBufferBuilder();
this.encryptLevel = encryptLevel;
}
public RequestUpdateModelBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestUpdateModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
return this;
}
public RequestUpdateModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public RequestUpdateModelBuilder id(String id) {
this.idOffset = this.builder.createString(id);
return this;
}
public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
switch (encryptLevel) {
case PW_ENCRYPT:
try {
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception");
}
case DP_ENCRYPT:
try {
int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize);
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
}
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ADBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
}
int featureSize = encryptFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length));
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
data[j] = data[j] * trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
fmOffsets[i] = featureMap;
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
return this;
}
}
public byte[] build() {
RequestUpdateModel.startRequestUpdateModel(this.builder);
RequestUpdateModel.addFlName(builder, nameOffset);
RequestUpdateModel.addFlId(this.builder, idOffset);
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
int root = RequestUpdateModel.endRequestUpdateModel(builder);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private String nextRequestTime;
private FLClientStatus status;
private static volatile UpdateModel updateModel;
private UpdateModel() {
}
public static synchronized UpdateModel getInstance() {
UpdateModel localRef = updateModel;
if (localRef == null) {
synchronized (UpdateModel.class) {
localRef = updateModel;
if (localRef == null) {
updateModel = localRef = new UpdateModel();
}
}
}
return localRef;
}
public String getNextRequestTime() {
return nextRequestTime;
}
public FLClientStatus getStatus() {
return status;
}
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
}
public FLClientStatus doResponse(ResponseUpdateModel response) {
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
nextRequestTime = response.nextReqTime();
switch (response.retcode()) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[updateModel] updateModel success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[updateModel]the return <retcode> from server is invalid: " + response.retcode()));
return FLClientStatus.FAILED;
}
}
}

View File

@ -45,6 +45,7 @@ public class ClientListReq {
private String nextRequestTime;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int retCode;
public ClientListReq() {
flCommunication = FLCommunication.getInstance();
@ -58,6 +59,10 @@ public class ClientListReq {
this.nextRequestTime = nextRequestTime;
}
public int getRetCode() {
return retCode;
}
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "=============="));
@ -81,15 +86,15 @@ public class ClientListReq {
}
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
int retcode = bufData.retcode();
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
LOGGER.info(Common.addTag("[PairWiseMask] the size of clients: " + bufData.clientsLength()));
FLClientStatus status;
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList success"));
u3ClientList.clear();
@ -117,7 +122,7 @@ public class ClientListReq {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnClientList is invalid: " + retcode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}

View File

@ -37,6 +37,7 @@ public class ReconstructSecretReq {
private String nextRequestTime;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int retCode;
public String getNextRequestTime() {
return nextRequestTime;
@ -46,6 +47,10 @@ public class ReconstructSecretReq {
this.nextRequestTime = nextRequestTime;
}
public int getRetCode() {
return retCode;
}
public ReconstructSecretReq() {
flCommunication = FLCommunication.getInstance();
}
@ -99,13 +104,13 @@ public class ReconstructSecretReq {
}
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
int retcode = bufData.retcode();
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
return FLClientStatus.SUCCESS;
@ -118,7 +123,7 @@ public class ReconstructSecretReq {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReconstructSecret is invalid: " + retcode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}

View File

@ -26,7 +26,7 @@ public class AdTrainBert extends AdBert {
private static volatile AdTrainBert adTrainBert;
public static AdTrainBert getInstance() {
AdTrainBert localRef = adInferBert;
AdTrainBert localRef = adTrainBert;
if (localRef == null) {
synchronized (AdTrainBert.class) {
localRef = adTrainBert;