forked from mindspore-Ecosystem/mindspore
!18797 add code of interactive rounds of startFLJob/updateModel/getModel for fl_client
Merge pull request !18797 from zhoushan33/flclient0624_om
This commit is contained in:
commit
6e6fb37498
|
@ -83,6 +83,7 @@ public class CipherClient {
|
|||
private Random random = new Random();
|
||||
private ClientListReq clientListReq = new ClientListReq();
|
||||
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
|
||||
private int retCode;
|
||||
|
||||
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
|
@ -110,6 +111,10 @@ public class CipherClient {
|
|||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public void genDHKeyPairs() {
|
||||
byte[] csk = keyAgreement.generatePrivateKey();
|
||||
byte[] cpk = keyAgreement.generatePublicKey(csk);
|
||||
|
@ -282,13 +287,13 @@ public class CipherClient {
|
|||
}
|
||||
|
||||
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
|
@ -301,7 +306,7 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseExchangeKeys is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -329,12 +334,12 @@ public class CipherClient {
|
|||
}
|
||||
|
||||
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys success"));
|
||||
clientPublicKeyList.clear();
|
||||
|
@ -366,7 +371,7 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnExchangeKeys is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -417,13 +422,13 @@ public class CipherClient {
|
|||
}
|
||||
|
||||
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
|
@ -436,7 +441,7 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseShareSecrets is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -464,13 +469,13 @@ public class CipherClient {
|
|||
}
|
||||
|
||||
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] the size of encrypted shares: " + bufData.encryptedSharesLength()));
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets success"));
|
||||
returnShareList.clear();
|
||||
|
@ -499,7 +504,7 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnShareSecrets is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
@ -563,6 +568,7 @@ public class CipherClient {
|
|||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
return curStatus;
|
||||
}
|
||||
retCode = clientListReq.getRetCode();
|
||||
|
||||
// SendReconstructSecret
|
||||
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
|
||||
|
@ -573,6 +579,7 @@ public class CipherClient {
|
|||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = reconstructSecretReq.getNextRequestTime();
|
||||
}
|
||||
retCode = reconstructSecretReq.getRetCode();
|
||||
return curStatus;
|
||||
}
|
||||
}
|
|
@ -88,7 +88,7 @@ public class Common {
|
|||
Date date = new Date();
|
||||
long currentTime = date.getTime();
|
||||
long waitTime = 0;
|
||||
if (!nextRequestTime.equals("")) {
|
||||
if (!("").equals(nextRequestTime)) {
|
||||
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
|
||||
}
|
||||
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
|
||||
|
@ -114,12 +114,20 @@ public class Common {
|
|||
return LOG_TITLE + message;
|
||||
}
|
||||
|
||||
public static boolean isAutoscaling(byte[] message, String autoscalingTag) {
|
||||
return (new String(message)).contains(autoscalingTag);
|
||||
public static boolean isSafeMod(byte[] message, String safeModTag) {
|
||||
return (new String(message)).contains(safeModTag);
|
||||
}
|
||||
|
||||
public static boolean checkPath(String path) {
|
||||
File file = new File(path);
|
||||
return file.exists();
|
||||
boolean tag = true;
|
||||
String [] paths = path.split(",");
|
||||
for (int i = 0; i < paths.length; i++) {
|
||||
LOGGER.info(addTag("[check path]:" + paths[i]));
|
||||
File file = new File(paths[i]);
|
||||
if (!file.exists()) {
|
||||
tag = false;
|
||||
}
|
||||
}
|
||||
return tag;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ public class FLCommunication implements IFLCommunication {
|
|||
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
|
||||
private OkHttpClient client;
|
||||
|
||||
private static FLCommunication communication;
|
||||
private static volatile FLCommunication communication;
|
||||
|
||||
private FLCommunication() {
|
||||
if (flParameter.getTimeOut() != 0) {
|
||||
|
@ -109,14 +109,16 @@ public class FLCommunication implements IFLCommunication {
|
|||
}
|
||||
|
||||
public static FLCommunication getInstance() {
|
||||
if (communication == null) {
|
||||
FLCommunication localRef = communication;
|
||||
if (localRef == null) {
|
||||
synchronized (FLCommunication.class) {
|
||||
if (communication == null) {
|
||||
communication = new FLCommunication();
|
||||
localRef = communication;
|
||||
if (localRef == null) {
|
||||
communication = localRef = new FLCommunication();
|
||||
}
|
||||
}
|
||||
}
|
||||
return communication;
|
||||
return localRef;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,535 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.cipher.BaseUtil;
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.CipherPublicParams;
|
||||
import mindspore.schema.FLPlan;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseFLJob;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
import mindspore.schema.ResponseUpdateModel;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class FLLiteClient {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
|
||||
private FLClientStatus status;
|
||||
private int retCode;
|
||||
|
||||
private static int iteration = 0;
|
||||
private int iterations = 1;
|
||||
private int epochs = 1;
|
||||
private int batchSize = 16;
|
||||
private int minSecretNum;
|
||||
private byte[] prime;
|
||||
private int featureSize;
|
||||
private int trainDataSize;
|
||||
private double dpEps = 100;
|
||||
private double dpDelta = 0.01;
|
||||
private double dpNormClip = 2.0;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private SecureProtocol secureProtocol = new SecureProtocol();
|
||||
private static Map<String, float[]> mapBeforeTrain;
|
||||
private String nextRequestTime;
|
||||
|
||||
public FLLiteClient() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
}
|
||||
|
||||
public int setGlobalParameters(ResponseFLJob flJob) {
|
||||
FLPlan flPlan = flJob.flPlanConfig();
|
||||
if (flPlan == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
|
||||
return -1;
|
||||
}
|
||||
iterations = flPlan.iterations();
|
||||
epochs = flPlan.epochs();
|
||||
batchSize = flPlan.miniBatch();
|
||||
String serverMod = flPlan.serverMode();
|
||||
localFLParameter.setServerMod(serverMod);
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <serverMod> from server: " + serverMod));
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AdTrainBert: " + batchSize));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
adTrainBert.setBatchSize(batchSize);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
trainLenet.setBatchSize(batchSize);
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
|
||||
CipherPublicParams cipherPublicParams = flPlan.cipher();
|
||||
String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
|
||||
localFLParameter.setEncryptLevel(encryptLevel);
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server: " + encryptLevel));
|
||||
switch (localFLParameter.getEncryptLevel()) {
|
||||
case PW_ENCRYPT:
|
||||
minSecretNum = cipherPublicParams.t();
|
||||
int primeLength = cipherPublicParams.primeLength();
|
||||
prime = new byte[primeLength];
|
||||
for (int i = 0; i < primeLength; i++) {
|
||||
prime[i] = (byte) cipherPublicParams.prime(i);
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
||||
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
|
||||
case DP_ENCRYPT:
|
||||
dpEps = cipherPublicParams.dpEps();
|
||||
dpDelta = cipherPublicParams.dpDelta();
|
||||
dpNormClip = cipherPublicParams.dpNormClip();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClip> from server: " + dpNormClip));
|
||||
break;
|
||||
default:
|
||||
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public int getIteration() {
|
||||
return iteration;
|
||||
}
|
||||
|
||||
public int getIterations() {
|
||||
return iterations;
|
||||
}
|
||||
|
||||
public int getEpochs() {
|
||||
return epochs;
|
||||
}
|
||||
|
||||
public int getBatchSize() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public void setTrainDataSize(int trainDataSize) {
|
||||
this.trainDataSize = trainDataSize;
|
||||
}
|
||||
|
||||
public FLClientStatus checkStatus() {
|
||||
return this.status;
|
||||
}
|
||||
|
||||
public FLClientStatus startFLJob() {
|
||||
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server===================================="));
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[startFLJob] ==============startFLJob url: " + url + "=============="));
|
||||
StartFLJob startFLJob = StartFLJob.getInstance();
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time);
|
||||
try {
|
||||
long start = Common.startTime("single startFLJob");
|
||||
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
|
||||
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again"));
|
||||
status = FLClientStatus.RESTART;
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] the response message length: " + message.length));
|
||||
Common.endTime(start, "single startFLJob");
|
||||
ByteBuffer buffer = ByteBuffer.wrap(message);
|
||||
ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer);
|
||||
status = judgeStartFLJob(startFLJob, responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
|
||||
iteration = responseDataBuf.iteration();
|
||||
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
|
||||
status = response;
|
||||
switch (response) {
|
||||
case SUCCESS:
|
||||
LOGGER.info(Common.addTag("[startFLJob] startFLJob success"));
|
||||
featureSize = startFLJob.getFeatureSize();
|
||||
secureProtocol.setEncryptFeatureName(startFLJob.getEncryptFeatureName());
|
||||
LOGGER.info(Common.addTag("[startFLJob] ***the feature size get in ResponseFLJob***: " + featureSize));
|
||||
int tag = setGlobalParameters(responseDataBuf);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] setGlobalParameters failed"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
break;
|
||||
case RESTART:
|
||||
FLPlan flPlan = responseDataBuf.flPlanConfig();
|
||||
iterations = flPlan.iterations();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||
nextRequestTime = responseDataBuf.nextReqTime();
|
||||
break;
|
||||
case FAILED:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
|
||||
break;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public FLClientStatus localTrain() {
|
||||
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
LOGGER.info(Common.addTag("[train] train in adbert"));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
int tag = adTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[train] unsolved error code in <adTrainBert.trainModel>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[train] train in lenet"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[train] unsolved error code in <trainLenet.trainModel>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public FLClientStatus updateModel() {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[updateModel] ==============updateModel url: " + url + "=============="));
|
||||
UpdateModel updateModelBuf = UpdateModel.getInstance();
|
||||
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
|
||||
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
|
||||
LOGGER.info(Common.addTag("[updateModel] catch error in build RequestUpdateFLJob"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
try {
|
||||
long start = Common.startTime("single updateModel");
|
||||
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
|
||||
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again"));
|
||||
status = FLClientStatus.RESTART;
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[updateModel] the response message length: " + message.length));
|
||||
Common.endTime(start, "single updateModel");
|
||||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer);
|
||||
status = updateModelBuf.doResponse(responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
if (status == FLClientStatus.RESTART) {
|
||||
nextRequestTime = responseDataBuf.nextReqTime();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public FLClientStatus getModel() {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
|
||||
GetModel getModelBuf = GetModel.getInstance();
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
|
||||
try {
|
||||
long start = Common.startTime("single getModel");
|
||||
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
|
||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
|
||||
status = FLClientStatus.WAIT;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] the response message length: " + message.length));
|
||||
Common.endTime(start, "single getModel");
|
||||
LOGGER.info(Common.addTag("[getModel] get model request success"));
|
||||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
||||
status = getModelBuf.doResponse(responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
if (status == FLClientStatus.RESTART) {
|
||||
nextRequestTime = responseDataBuf.timestamp();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] get response from server ok!"));
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[getModel] un sloved error code: catch IOException: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
|
||||
if (mapBeforeTrain == null) {
|
||||
Map<String, float[]> copyMap = new TreeMap<>();
|
||||
for (String key : map.keySet()) {
|
||||
float[] data = map.get(key);
|
||||
int dataLen = data.length;
|
||||
float[] weights = new float[dataLen];
|
||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
|
||||
for (int j = 0; j < dataLen; j++) {
|
||||
float weight = data[j];
|
||||
weights[j] = weight;
|
||||
}
|
||||
copyMap.put(key, weights);
|
||||
}
|
||||
}
|
||||
mapBeforeTrain = copyMap;
|
||||
} else {
|
||||
for (String key : map.keySet()) {
|
||||
float[] data = map.get(key);
|
||||
float[] copyData = mapBeforeTrain.get(key);
|
||||
int dataLen = data.length;
|
||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
|
||||
for (int j = 0; j < dataLen; j++) {
|
||||
copyData[j] = data[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return mapBeforeTrain;
|
||||
}
|
||||
|
||||
public FLClientStatus getFeatureMask() {
|
||||
FLClientStatus curStatus;
|
||||
switch (localFLParameter.getEncryptLevel()) {
|
||||
case PW_ENCRYPT:
|
||||
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">"));
|
||||
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
|
||||
curStatus = secureProtocol.pwCreateMask();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = secureProtocol.getNextRequestTime();
|
||||
}
|
||||
retCode = secureProtocol.getRetCode();
|
||||
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
||||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
Map<String, float[]> copyMap = getOldMapCopy(map);
|
||||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClip, copyMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case NOT_ENCRYPT:
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("[Encrypt] don't mask model"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
default:
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus unMasking() {
|
||||
FLClientStatus curStatus;
|
||||
switch (localFLParameter.getEncryptLevel()) {
|
||||
case PW_ENCRYPT:
|
||||
curStatus = secureProtocol.pwUnmasking();
|
||||
retCode = secureProtocol.getRetCode();
|
||||
LOGGER.info(Common.addTag("[Encrypt] the response of unmasking : " + curStatus));
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = secureProtocol.getNextRequestTime();
|
||||
}
|
||||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
return FLClientStatus.SUCCESS;
|
||||
case NOT_ENCRYPT:
|
||||
LOGGER.info(Common.addTag("[Encrypt] haven't mask model"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
return FLClientStatus.SUCCESS;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[Encrypt] The encrypt level is error, not encrypt by default"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus evaluateModel() {
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("===================================test combine model from server==================================="));
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
int dataSize = adInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.initDataSet>: the return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
float acc = adInferBert.evalModel();
|
||||
if (acc == Float.NaN) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.evalModel>: the return acc is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
float acc = trainLenet.evalModel();
|
||||
if (acc == Float.NaN) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1]));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dataPath, train or test dataset and label set
|
||||
*/
|
||||
public int setInput(String dataPath) {
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("==========set input==========="));
|
||||
int dataSize = 0;
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
dataSize = adTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1]));
|
||||
}
|
||||
if (dataSize <= 0) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return -1;
|
||||
}
|
||||
return dataSize;
|
||||
}
|
||||
|
||||
public FLClientStatus initSession() {
|
||||
int tag = 0;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
SessionUtil.free(adTrainBert.getTrainSession());
|
||||
if (!flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,307 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class FLParameter {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
||||
|
||||
public static final int TIME_OUT = 100;
|
||||
public static final int SLEEP_TIME = 1000;
|
||||
private String hostName;
|
||||
private String certPath;
|
||||
|
||||
private String trainDataset;
|
||||
private String vocabFile = "null";
|
||||
private String idsFile = "null";
|
||||
private String testDataset = "null";
|
||||
private String flName;
|
||||
private String trainModelPath;
|
||||
private String inferModelPath;
|
||||
private String clientID;
|
||||
private String ip;
|
||||
private int port;
|
||||
private boolean useSSL = false;
|
||||
private int timeOut;
|
||||
private int sleepTime;
|
||||
private boolean useElb = false;
|
||||
private int serverNum = 1;
|
||||
|
||||
private boolean timer = true;
|
||||
private int timeWindow = 6000;
|
||||
private int reRequestNum = timeWindow / SLEEP_TIME + 1;
|
||||
|
||||
private static volatile FLParameter flParameter;
|
||||
|
||||
public static FLParameter getInstance() {
|
||||
FLParameter localRef = flParameter;
|
||||
if (localRef == null) {
|
||||
synchronized (FLParameter.class) {
|
||||
localRef = flParameter;
|
||||
if (localRef == null) {
|
||||
flParameter = localRef = new FLParameter();
|
||||
}
|
||||
}
|
||||
}
|
||||
return localRef;
|
||||
}
|
||||
|
||||
public String getHostName() {
|
||||
if ("".equals(hostName) || hostName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <hostName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return hostName;
|
||||
}
|
||||
|
||||
public void setHostName(String hostName) {
|
||||
this.hostName = hostName;
|
||||
}
|
||||
|
||||
public String getCertPath() {
|
||||
if ("".equals(certPath) || certPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return certPath;
|
||||
}
|
||||
|
||||
public void setCertPath(String certPath) {
|
||||
this.certPath = certPath;
|
||||
}
|
||||
|
||||
public String getTrainDataset() {
|
||||
if ("".equals(trainDataset) || trainDataset.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return trainDataset;
|
||||
}
|
||||
|
||||
public void setTrainDataset(String trainDataset) {
|
||||
if (Common.checkPath(trainDataset)) {
|
||||
this.trainDataset = trainDataset;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getVocabFile() {
|
||||
if ("null".equals(vocabFile) && "adbert".equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return vocabFile;
|
||||
}
|
||||
|
||||
public void setVocabFile(String vocabFile) {
|
||||
if (Common.checkPath(vocabFile)) {
|
||||
this.vocabFile = vocabFile;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getIdsFile() {
|
||||
if ("null".equals(idsFile) && "adbert".equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return idsFile;
|
||||
}
|
||||
|
||||
public void setIdsFile(String idsFile) {
|
||||
if (Common.checkPath(idsFile)) {
|
||||
this.idsFile = idsFile;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getTestDataset() {
|
||||
return testDataset;
|
||||
}
|
||||
|
||||
public void setTestDataset(String testDataset) {
|
||||
if (Common.checkPath(testDataset)) {
|
||||
this.testDataset = testDataset;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getFlName() {
|
||||
if ("".equals(flName) || flName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return flName;
|
||||
}
|
||||
|
||||
public void setFlName(String flName) {
|
||||
if (Common.checkFLName(flName)) {
|
||||
this.flName = flName;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in flNameTrustList, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getTrainModelPath() {
|
||||
if ("".equals(trainModelPath) || trainModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return trainModelPath;
|
||||
}
|
||||
|
||||
public void setTrainModelPath(String trainModelPath) {
|
||||
if (Common.checkPath(trainModelPath)) {
|
||||
this.trainModelPath = trainModelPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getInferModelPath() {
|
||||
if ("".equals(inferModelPath) || inferModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return inferModelPath;
|
||||
}
|
||||
|
||||
public void setInferModelPath(String inferModelPath) {
|
||||
if (Common.checkPath(inferModelPath)) {
|
||||
this.inferModelPath = inferModelPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getIp() {
|
||||
if ("".equals(ip) || ip.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return ip;
|
||||
}
|
||||
|
||||
public void setIp(String ip) {
|
||||
this.ip = ip;
|
||||
}
|
||||
|
||||
public boolean isUseSSL() {
|
||||
return useSSL;
|
||||
}
|
||||
|
||||
public void setUseSSL(boolean useSSL) {
|
||||
this.useSSL = useSSL;
|
||||
}
|
||||
|
||||
public int getPort() {
|
||||
if (port == 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return port;
|
||||
}
|
||||
|
||||
public void setPort(int port) {
|
||||
this.port = port;
|
||||
}
|
||||
|
||||
|
||||
public int getTimeOut() {
|
||||
return timeOut;
|
||||
}
|
||||
|
||||
public void setTimeOut(int timeOut) {
|
||||
this.timeOut = timeOut;
|
||||
}
|
||||
|
||||
public int getSleepTime() {
|
||||
return sleepTime;
|
||||
}
|
||||
|
||||
public void setSleepTime(int sleepTime) {
|
||||
this.sleepTime = sleepTime;
|
||||
}
|
||||
|
||||
public boolean isUseElb() {
|
||||
return useElb;
|
||||
}
|
||||
|
||||
public void setUseElb(boolean useElb) {
|
||||
this.useElb = useElb;
|
||||
}
|
||||
|
||||
public int getServerNum() {
|
||||
if (serverNum == 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is zero, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return serverNum;
|
||||
}
|
||||
|
||||
public void setServerNum(int serverNum) {
|
||||
this.serverNum = serverNum;
|
||||
}
|
||||
|
||||
public boolean isTimer() {
|
||||
return timer;
|
||||
}
|
||||
|
||||
public void setTimer(boolean timer) {
|
||||
this.timer = timer;
|
||||
}
|
||||
|
||||
public int getTimeWindow() {
|
||||
return timeWindow;
|
||||
}
|
||||
|
||||
public void setTimeWindow(int timeWindow) {
|
||||
this.timeWindow = timeWindow;
|
||||
}
|
||||
|
||||
public int getReRequestNum() {
|
||||
return reRequestNum;
|
||||
}
|
||||
|
||||
public void setReRequestNum(int reRequestNum) {
|
||||
this.reRequestNum = reRequestNum;
|
||||
}
|
||||
|
||||
public String getClientID() {
|
||||
if ("".equals(clientID) || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return clientID;
|
||||
}
|
||||
|
||||
public void setClientID(String clientID) {
|
||||
this.clientID = clientID;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestGetModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class GetModel {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
class RequestGetModelBuilder {
|
||||
private FlatBufferBuilder builder;
|
||||
private int nameOffset = 0;
|
||||
private int iteration = 0;
|
||||
private int timeStampOffset = 0;
|
||||
|
||||
public RequestGetModelBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
}
|
||||
|
||||
public RequestGetModelBuilder flName(String name) {
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestGetModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timeStampOffset = builder.createString(String.valueOf(time));
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestGetModelBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
public byte[] build() {
|
||||
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
|
||||
private static GetModel getModel;
|
||||
|
||||
private GetModel() {
|
||||
}
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
||||
public static GetModel getInstance() {
|
||||
if (getModel == null) {
|
||||
getModel = new GetModel();
|
||||
}
|
||||
return getModel;
|
||||
}
|
||||
|
||||
public byte[] getRequestGetModel(String name, int iteration) {
|
||||
RequestGetModelBuilder builder = new RequestGetModelBuilder();
|
||||
return builder.iteration(iteration).flName(name).time().build();
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus parseResponseAdbert(ResponseGetModel responseDataBuf) {
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseLenet(ResponseGetModel responseDataBuf) {
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
|
||||
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========reason: " + responseDataBuf.reason()));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========iteration: " + responseDataBuf.iteration()));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========time: " + responseDataBuf.timestamp()));
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
int retCode = responseDataBuf.retcode();
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[getModel] getModel response success"));
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
|
||||
status = parseResponseAdbert(responseDataBuf);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
status = parseResponseLenet(responseDataBuf);
|
||||
}
|
||||
return status;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[getModel] server is not ready now: need wait and request getModel again"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[getModel] out of time: need wait and request startFLJob again"));
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.warning(Common.addTag("[getModel] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[getModel] the return <retCode> from server is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class LocalFLParameter {
|
||||
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
|
||||
public static final int SEED_SIZE = 32;
|
||||
public static final int IVEC_LEN = 16;
|
||||
public static final String LENET = "lenet";
|
||||
public static final String ADBERT = "adbert";
|
||||
private List<String> classifierWeightName = new ArrayList<>();
|
||||
private List<String> albertWeightName = new ArrayList<>();
|
||||
|
||||
private String flID;
|
||||
private String encryptLevel = "NotEncrypt";
|
||||
private String earlyStopMod = "NotEarlyStop";
|
||||
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
|
||||
private String safeMod = "The cluster is in safemode.";
|
||||
|
||||
private static volatile LocalFLParameter localFLParameter;
|
||||
|
||||
private LocalFLParameter() {
|
||||
// set classifierWeightName albertWeightName
|
||||
Common.setClassifierWeightName(classifierWeightName);
|
||||
Common.setAlbertWeightName(albertWeightName);
|
||||
}
|
||||
|
||||
public static synchronized LocalFLParameter getInstance() {
|
||||
LocalFLParameter localRef = localFLParameter;
|
||||
if (localRef == null) {
|
||||
synchronized (LocalFLParameter.class) {
|
||||
localRef = localFLParameter;
|
||||
if (localRef == null) {
|
||||
localFLParameter = localRef = new LocalFLParameter();
|
||||
}
|
||||
}
|
||||
}
|
||||
return localRef;
|
||||
}
|
||||
|
||||
public List<String> getClassifierWeightName() {
|
||||
if (classifierWeightName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return classifierWeightName;
|
||||
}
|
||||
|
||||
public void setClassifierWeightName(List<String> classifierWeightName) {
|
||||
this.classifierWeightName = classifierWeightName;
|
||||
}
|
||||
|
||||
public List<String> getAlbertWeightName() {
|
||||
if (albertWeightName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return albertWeightName;
|
||||
}
|
||||
|
||||
public void setAlbertWeightName(List<String> albertWeightName) {
|
||||
this.albertWeightName = albertWeightName;
|
||||
}
|
||||
|
||||
public String getFlID() {
|
||||
if ("".equals(flID) || flID == null) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return flID;
|
||||
}
|
||||
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
public EncryptLevel getEncryptLevel() {
|
||||
return EncryptLevel.valueOf(encryptLevel);
|
||||
}
|
||||
|
||||
public void setEncryptLevel(String encryptLevel) {
|
||||
this.encryptLevel = encryptLevel;
|
||||
}
|
||||
|
||||
public EarlyStopMod getEarlyStopMod() {
|
||||
return EarlyStopMod.valueOf(earlyStopMod);
|
||||
}
|
||||
|
||||
public void setEarlyStopMod(String earlyStopMod) {
|
||||
this.earlyStopMod = earlyStopMod;
|
||||
}
|
||||
|
||||
public String getServerMod() {
|
||||
return serverMod;
|
||||
}
|
||||
|
||||
public void setServerMod(String serverMod) {
|
||||
this.serverMod = serverMod;
|
||||
}
|
||||
|
||||
public String getSafeMod() {
|
||||
return safeMod;
|
||||
}
|
||||
|
||||
public void setSafeMod(String safeMod) {
|
||||
this.safeMod = safeMod;
|
||||
}
|
||||
}
|
|
@ -42,6 +42,7 @@ public class SecureProtocol {
|
|||
private static double deltaError = 1e-6;
|
||||
private static Map<String, float[]> modelMap;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
private int retCode;
|
||||
|
||||
public FLClientStatus getStatus() {
|
||||
return status;
|
||||
|
@ -51,6 +52,10 @@ public class SecureProtocol {
|
|||
return featureMask;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public SecureProtocol() {
|
||||
}
|
||||
|
||||
|
@ -91,6 +96,7 @@ public class SecureProtocol {
|
|||
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
|
||||
// round 0
|
||||
status = cipher.exchangeKeys();
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
|
||||
if (status != FLClientStatus.SUCCESS) {
|
||||
return status;
|
||||
|
@ -98,6 +104,7 @@ public class SecureProtocol {
|
|||
// round 1
|
||||
try {
|
||||
status = cipher.shareSecrets();
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
||||
|
@ -109,6 +116,7 @@ public class SecureProtocol {
|
|||
// round2
|
||||
try {
|
||||
featureMask = cipher.doubleMaskingWeight();
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
||||
|
@ -151,6 +159,7 @@ public class SecureProtocol {
|
|||
|
||||
public FLClientStatus pwUnmasking() {
|
||||
status = cipher.reconstructSecrets(); // round3
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
|
||||
return status;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestFLJob;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseFLJob;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class StartFLJob {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
|
||||
|
||||
class RequestStartFLJobBuilder {
|
||||
private RequestFLJob requestFLJob;
|
||||
private FlatBufferBuilder builder;
|
||||
private int nameOffset = 0;
|
||||
private int iteration = 0;
|
||||
private int dataSize = 0;
|
||||
private int timestampOffset = 0;
|
||||
private int idOffset = 0;
|
||||
|
||||
public RequestStartFLJobBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
}
|
||||
|
||||
public RequestStartFLJobBuilder flName(String name) {
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestStartFLJobBuilder id(String id) {
|
||||
this.idOffset = this.builder.createString(id);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestStartFLJobBuilder time(long timestamp) {
|
||||
this.timestampOffset = builder.createString(String.valueOf(timestamp));
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestStartFLJobBuilder dataSize(int dataSize) {
|
||||
// temp code need confirm
|
||||
this.dataSize = dataSize;
|
||||
LOGGER.info(Common.addTag("[startFLJob] the train data size: " + dataSize));
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestStartFLJobBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
public byte[] build() {
|
||||
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
|
||||
this.dataSize, this.timestampOffset);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
private static StartFLJob startFLJob;
|
||||
|
||||
private FLClientStatus status;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int featureSize;
|
||||
private String nextRequestTime;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
|
||||
private StartFLJob() {
|
||||
|
||||
}
|
||||
|
||||
public static StartFLJob getInstance() {
|
||||
if (startFLJob == null) {
|
||||
startFLJob = new StartFLJob();
|
||||
}
|
||||
return startFLJob;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
|
||||
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
|
||||
return builder.flName(flParameter.getFlName())
|
||||
.time(time)
|
||||
.id(localFLParameter.getFlID())
|
||||
.dataSize(dataSize)
|
||||
.iteration(iteration)
|
||||
.build();
|
||||
}
|
||||
|
||||
public int getFeatureSize() {
|
||||
return featureSize;
|
||||
}
|
||||
|
||||
public ArrayList<String> getEncryptFeatureName() {
|
||||
return encryptFeatureName;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseAdbert(ResponseFLJob flJob) {
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
encryptFeatureName.clear();
|
||||
if (fmCount <= 0) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(feature.weightFullname());
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseLenet(ResponseFLJob flJob) {
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
encryptFeatureName.clear();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
public FLClientStatus doResponse(ResponseFLJob flJob) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] return code: " + flJob.retcode()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
|
||||
nextRequestTime = flJob.nextReqTime();
|
||||
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
|
||||
int retcode = flJob.retcode();
|
||||
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAdbert>"));
|
||||
parseResponseAdbert(flJob);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
||||
parseResponseLenet(flJob);
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus getStatus() {
|
||||
return this.status;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,322 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class SyncFLJob {
|
||||
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
|
||||
|
||||
public SyncFLJob() {
|
||||
}
|
||||
|
||||
public void flJobRun() {
|
||||
localFLParameter.setFlID(flParameter.getClientID());
|
||||
FLLiteClient client = new FLLiteClient();
|
||||
client.initSession();
|
||||
|
||||
FLClientStatus curStatus;
|
||||
do {
|
||||
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
|
||||
int trainDataSize = client.setInput(flParameter.getTrainDataset());
|
||||
LOGGER.info(Common.addTag("train path: " + flParameter.getTrainDataset()));
|
||||
LOGGER.info(Common.addTag("train data size: " + trainDataSize));
|
||||
if (trainDataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
|
||||
curStatus = FLClientStatus.FAILED;
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
|
||||
break;
|
||||
}
|
||||
client.setTrainDataSize(trainDataSize);
|
||||
|
||||
// startFLJob
|
||||
curStatus = client.startFLJob();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.startFLJob();
|
||||
}
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[startFLJob]", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration()));
|
||||
|
||||
// create mask
|
||||
curStatus = client.getFeatureMask();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[Encrypt] creatMask", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[Encrypt] createMask", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
|
||||
// train
|
||||
curStatus = client.localTrain();
|
||||
if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[train] train", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[train] train succeed"));
|
||||
|
||||
// updateModel
|
||||
curStatus = client.updateModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.updateModel();
|
||||
}
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[updateModel] updateModel", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[updateModel] updateModel succeed"));
|
||||
|
||||
// unmasking
|
||||
curStatus = client.unMasking();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[Encrypt] unmasking", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[Encrypt] unmasking", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
|
||||
// getModel
|
||||
curStatus = client.getModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.getModel();
|
||||
}
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[getModel] getModel", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] getModel succeed"));
|
||||
|
||||
//evaluate model after getting model from server
|
||||
if (flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
|
||||
} else {
|
||||
curStatus = client.evaluateModel();
|
||||
if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[evaluate] evaluate", client.getIteration(), client.getRetCode(), curStatus);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
|
||||
}
|
||||
LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================"));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
|
||||
} while (client.getIteration() < client.getIterations());
|
||||
client.finalize();
|
||||
LOGGER.info(Common.addTag("flJobRun finish"));
|
||||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
||||
}
|
||||
|
||||
public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) {
|
||||
int[] labels = new int[0];
|
||||
if (flName.equals(ADBERT)) {
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
LOGGER.info(Common.addTag("===========model inference============="));
|
||||
labels = adInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile);
|
||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
} else if (flName.equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
LOGGER.info(Common.addTag("===========model inference============="));
|
||||
labels = trainLenet.inferModel(modelPath, dataPath.split(",")[0]);
|
||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
}
|
||||
if (labels.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
|
||||
public FLClientStatus getModel(boolean useElb, int serverNum, String ip, int port, String flName, String trainModelPath, String inferModelPath, boolean useSSL) {
|
||||
int tag = 0;
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
try {
|
||||
if (flName.equals(ADBERT)) {
|
||||
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = adTrainBert.initSessionAndInputs(trainModelPath, true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + inferModelPath + " Create inference Session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = adInferBert.initSessionAndInputs(inferModelPath, false);
|
||||
} else if (flName.equals(LENET)) {
|
||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(trainModelPath, true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
flParameter.setUseSSL(useSSL);
|
||||
FLCommunication flCommunication = FLCommunication.getInstance();
|
||||
String url = Common.generateUrl(useElb, ip, port, serverNum);
|
||||
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
|
||||
GetModel getModelBuf = GetModel.getInstance();
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flName, 0);
|
||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||
LOGGER.info(Common.addTag("[getModel] get model request success"));
|
||||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
||||
status = getModelBuf.doResponse(responseDataBuf);
|
||||
LOGGER.info(Common.addTag("[getModel] success!"));
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
if (flName.equals(ADBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
SessionUtil.free(adTrainBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
} else if (flName.equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
private void waitSomeTime() {
|
||||
if (flParameter.getSleepTime() != 0)
|
||||
Common.sleep(flParameter.getSleepTime());
|
||||
else
|
||||
Common.sleep(SLEEP_TIME);
|
||||
}
|
||||
|
||||
private void waitNextReqTime(String nextReqTime) {
|
||||
if (flParameter.isTimer()) {
|
||||
long waitTime = Common.getWaitTime(nextReqTime);
|
||||
Common.sleep(waitTime);
|
||||
} else {
|
||||
waitSomeTime();
|
||||
}
|
||||
}
|
||||
|
||||
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
|
||||
LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again"));
|
||||
waitNextReqTime(nextReqTime);
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
||||
}
|
||||
|
||||
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
|
||||
LOGGER.info(Common.addTag(tag + " failed"));
|
||||
LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================"));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
String trainDataset = args[0];
|
||||
String vocabFile = args[1];
|
||||
String idsFile = args[2];
|
||||
String testDataset = args[3];
|
||||
String flName = args[4];
|
||||
String trainModelPath = args[5];
|
||||
String inferModelPath = args[6];
|
||||
String clientID = args[7];
|
||||
String ip = args[8];
|
||||
boolean useSSL = Boolean.parseBoolean(args[9]);
|
||||
int port = Integer.parseInt(args[10]);
|
||||
int timeWindow = Integer.parseInt(args[11]);
|
||||
boolean useElb = Boolean.parseBoolean(args[12]);
|
||||
int serverNum = Integer.parseInt(args[13]);
|
||||
String task = args[14];
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
|
||||
LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile));
|
||||
LOGGER.info(Common.addTag("[args] idsFile: " + idsFile));
|
||||
LOGGER.info(Common.addTag("[args] testDataset: " + testDataset));
|
||||
LOGGER.info(Common.addTag("[args] flName: " + flName));
|
||||
LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath));
|
||||
LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath));
|
||||
LOGGER.info(Common.addTag("[args] clientID: " + clientID));
|
||||
LOGGER.info(Common.addTag("[args] ip: " + ip));
|
||||
LOGGER.info(Common.addTag("[args] useSSL: " + useSSL));
|
||||
LOGGER.info(Common.addTag("[args] port: " + port));
|
||||
LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow));
|
||||
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
|
||||
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
|
||||
LOGGER.info(Common.addTag("[args] task: " + task));
|
||||
|
||||
flParameter.setClientID(clientID);
|
||||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
if (task.equals("train")) {
|
||||
flParameter.setTrainDataset(trainDataset);
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setTestDataset(testDataset);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setIp(ip);
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setPort(port);
|
||||
flParameter.setTimeWindow(timeWindow);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
if (ADBERT.equals(flName)) {
|
||||
flParameter.setVocabFile(vocabFile);
|
||||
flParameter.setIdsFile(idsFile);
|
||||
}
|
||||
syncFLJob.flJobRun();
|
||||
} else if (task.equals("inference")) {
|
||||
syncFLJob.modelInference(flName, testDataset, vocabFile, idsFile, inferModelPath);
|
||||
} else if (task.equals("getModel")) {
|
||||
syncFLJob.getModel(false, 1, ip, port, flName, trainModelPath, inferModelPath, false);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("do not do any thing!"));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestUpdateModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseUpdateModel;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class UpdateModel {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
class RequestUpdateModelBuilder {
|
||||
private RequestUpdateModel requestUM;
|
||||
private FlatBufferBuilder builder;
|
||||
private int fmOffset = 0;
|
||||
private int nameOffset = 0;
|
||||
private int idOffset = 0;
|
||||
private int timestampOffset = 0;
|
||||
private int iteration = 0;
|
||||
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
||||
|
||||
public RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
|
||||
builder = new FlatBufferBuilder();
|
||||
this.encryptLevel = encryptLevel;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder flName(String name) {
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timestampOffset = builder.createString(String.valueOf(time));
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder id(String id) {
|
||||
this.idOffset = this.builder.createString(id);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
|
||||
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
|
||||
switch (encryptLevel) {
|
||||
case PW_ENCRYPT:
|
||||
try {
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||
return this;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception");
|
||||
}
|
||||
case DP_ENCRYPT:
|
||||
try {
|
||||
int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||
return this;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
|
||||
}
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int[] fmOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
data[j] = data[j] * trainDataSize;
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
fmOffsets[i] = featureMap;
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] build() {
|
||||
RequestUpdateModel.startRequestUpdateModel(this.builder);
|
||||
RequestUpdateModel.addFlName(builder, nameOffset);
|
||||
RequestUpdateModel.addFlId(this.builder, idOffset);
|
||||
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
||||
RequestUpdateModel.addIteration(builder, this.iteration);
|
||||
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
||||
int root = RequestUpdateModel.endRequestUpdateModel(builder);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private String nextRequestTime;
|
||||
private FLClientStatus status;
|
||||
private static volatile UpdateModel updateModel;
|
||||
|
||||
private UpdateModel() {
|
||||
}
|
||||
|
||||
public static synchronized UpdateModel getInstance() {
|
||||
UpdateModel localRef = updateModel;
|
||||
if (localRef == null) {
|
||||
synchronized (UpdateModel.class) {
|
||||
localRef = updateModel;
|
||||
if (localRef == null) {
|
||||
updateModel = localRef = new UpdateModel();
|
||||
}
|
||||
}
|
||||
}
|
||||
return localRef;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public FLClientStatus getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
|
||||
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
|
||||
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
|
||||
}
|
||||
|
||||
public FLClientStatus doResponse(ResponseUpdateModel response) {
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
|
||||
nextRequestTime = response.nextReqTime();
|
||||
switch (response.retcode()) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[updateModel] updateModel success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[updateModel]the return <retcode> from server is invalid: " + response.retcode()));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -45,6 +45,7 @@ public class ClientListReq {
|
|||
private String nextRequestTime;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int retCode;
|
||||
|
||||
public ClientListReq() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
|
@ -58,6 +59,10 @@ public class ClientListReq {
|
|||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "=============="));
|
||||
|
@ -81,15 +86,15 @@ public class ClientListReq {
|
|||
}
|
||||
|
||||
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
int retcode = bufData.retcode();
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] the size of clients: " + bufData.clientsLength()));
|
||||
FLClientStatus status;
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList success"));
|
||||
u3ClientList.clear();
|
||||
|
@ -117,7 +122,7 @@ public class ClientListReq {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnClientList is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ public class ReconstructSecretReq {
|
|||
private String nextRequestTime;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int retCode;
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
|
@ -46,6 +47,10 @@ public class ReconstructSecretReq {
|
|||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public ReconstructSecretReq() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
}
|
||||
|
@ -99,13 +104,13 @@ public class ReconstructSecretReq {
|
|||
}
|
||||
|
||||
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
|
@ -118,7 +123,7 @@ public class ReconstructSecretReq {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReconstructSecret is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ public class AdTrainBert extends AdBert {
|
|||
private static volatile AdTrainBert adTrainBert;
|
||||
|
||||
public static AdTrainBert getInstance() {
|
||||
AdTrainBert localRef = adInferBert;
|
||||
AdTrainBert localRef = adTrainBert;
|
||||
if (localRef == null) {
|
||||
synchronized (AdTrainBert.class) {
|
||||
localRef = adTrainBert;
|
||||
|
|
Loading…
Reference in New Issue