forked from mindspore-Ecosystem/mindspore
!27594 add code of flclient for device decoupling
Merge pull request !27594 from zhoushan33/flclient1213_decoupling
This commit is contained in:
commit
6331c01405
|
@ -0,0 +1,12 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
/**
|
||||
* The cpu bind mod.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public enum BindMode {
|
||||
NOT_BINDING_CORE,
|
||||
BIND_LARGE_CORE,
|
||||
BIND_MIDDLE_CORE
|
||||
}
|
|
@ -25,33 +25,44 @@ import com.google.flatbuffers.FlatBufferBuilder;
|
|||
|
||||
import com.mindspore.flclient.cipher.AESEncrypt;
|
||||
import com.mindspore.flclient.cipher.BaseUtil;
|
||||
import com.mindspore.flclient.cipher.CertVerify;
|
||||
import com.mindspore.flclient.cipher.ClientListReq;
|
||||
import com.mindspore.flclient.cipher.KEYAgreement;
|
||||
import com.mindspore.flclient.cipher.Masking;
|
||||
import com.mindspore.flclient.cipher.ReconstructSecretReq;
|
||||
import com.mindspore.flclient.cipher.ShareSecrets;
|
||||
import com.mindspore.flclient.cipher.SignAndVerify;
|
||||
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
|
||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||
import com.mindspore.flclient.cipher.struct.ShareSecret;
|
||||
import com.mindspore.flclient.pki.PkiUtil;
|
||||
|
||||
import mindspore.schema.ClientShare;
|
||||
import mindspore.schema.GetExchangeKeys;
|
||||
import mindspore.schema.GetShareSecrets;
|
||||
import mindspore.schema.RequestAllClientListSign;
|
||||
import mindspore.schema.RequestExchangeKeys;
|
||||
import mindspore.schema.RequestShareSecrets;
|
||||
import mindspore.schema.ResponseClientListSign;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseExchangeKeys;
|
||||
import mindspore.schema.ResponseShareSecrets;
|
||||
import mindspore.schema.ReturnAllClientListSign;
|
||||
import mindspore.schema.ReturnExchangeKeys;
|
||||
import mindspore.schema.ReturnShareSecrets;
|
||||
import mindspore.schema.SendClientListSign;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.SecureRandom;
|
||||
import java.security.cert.CertificateEncodingException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -93,6 +104,7 @@ public class CipherClient {
|
|||
private ClientListReq clientListReq = new ClientListReq();
|
||||
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
|
||||
private int retCode;
|
||||
private Map<String, X509Certificate[]> certificateList = new HashMap<String, X509Certificate[]>();
|
||||
|
||||
/**
|
||||
* Construct function of cipherClient
|
||||
|
@ -423,12 +435,6 @@ public class CipherClient {
|
|||
" please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
byte[] cPK = cKey.get(0);
|
||||
byte[] sPK = sKey.get(0);
|
||||
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
|
||||
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
|
||||
|
||||
byte[] indIv = new byte[I_VEC_LEN];
|
||||
byte[] pwIv = new byte[I_VEC_LEN];
|
||||
byte[] thisPwSalt = new byte[SALT_SIZE];
|
||||
|
@ -439,15 +445,24 @@ public class CipherClient {
|
|||
this.individualIv = indIv;
|
||||
this.pwIVec = pwIv;
|
||||
this.pwSalt = thisPwSalt;
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
|
||||
int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv);
|
||||
int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv);
|
||||
int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt);
|
||||
|
||||
// for pkiVerify mode
|
||||
int exchangeKeysRoot;
|
||||
byte[] cPK = cKey.get(0);
|
||||
byte[] sPK = sKey.get(0);
|
||||
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
|
||||
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
String clientID = flParameter.getClientID();
|
||||
|
||||
// start build
|
||||
RequestExchangeKeys.startRequestExchangeKeys(fbBuilder);
|
||||
|
@ -459,7 +474,24 @@ public class CipherClient {
|
|||
RequestExchangeKeys.addIndIv(fbBuilder, indIvFbs);
|
||||
RequestExchangeKeys.addPwIv(fbBuilder, pwIvFbs);
|
||||
RequestExchangeKeys.addPwSalt(fbBuilder, pwSaltFbs);
|
||||
int exchangeKeysRoot = RequestExchangeKeys.endRequestExchangeKeys(fbBuilder);
|
||||
if (flParameter.isPkiVerify()) {
|
||||
// waiting for certificates to take effect
|
||||
int waitTakeEffectTime = 5000;
|
||||
Common.sleep(waitTakeEffectTime);
|
||||
int nSize = 2; // exchange equipment certificate and service equipment
|
||||
String[] pemCertificateChains = transformX509ArrayToPemArray(CertVerify.getX509CertificateChain(clientID));
|
||||
int[] pemList = new int[nSize];
|
||||
for (int i = 0; i < nSize; i++) {
|
||||
pemList[i] = fbBuilder.createString(pemCertificateChains[i]);
|
||||
}
|
||||
int certificatesInt = RequestExchangeKeys.createCertificateChainVector(fbBuilder, pemList);
|
||||
byte[] signature = signPkAndTime(clientID, cPK, sPK, dateTime, iteration);
|
||||
int signed = RequestExchangeKeys.createSignatureVector(fbBuilder, signature);
|
||||
|
||||
RequestExchangeKeys.addSignature(fbBuilder, signed);
|
||||
RequestExchangeKeys.addCertificateChain(fbBuilder, certificatesInt);
|
||||
}
|
||||
exchangeKeysRoot = RequestExchangeKeys.endRequestExchangeKeys(fbBuilder);
|
||||
fbBuilder.finish(exchangeKeysRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
|
@ -472,6 +504,7 @@ public class CipherClient {
|
|||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
|
@ -518,12 +551,30 @@ public class CipherClient {
|
|||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
|
||||
// start build
|
||||
GetExchangeKeys.startGetExchangeKeys(fbBuilder);
|
||||
GetExchangeKeys.addFlId(fbBuilder, id);
|
||||
GetExchangeKeys.addIteration(fbBuilder, iteration);
|
||||
GetExchangeKeys.addTimestamp(fbBuilder, time);
|
||||
int getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder);
|
||||
int getExchangeKeysRoot;
|
||||
byte[] signature = signTimeAndIter(dateTime, iteration);
|
||||
if (signature == null) {
|
||||
LOGGER.severe(Common.addTag("[getExchangeKeys] get signature is null!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (signature.length > 0) {
|
||||
int signed = GetExchangeKeys.createSignatureVector(fbBuilder, signature);
|
||||
// start build
|
||||
GetExchangeKeys.startGetExchangeKeys(fbBuilder);
|
||||
GetExchangeKeys.addFlId(fbBuilder, id);
|
||||
GetExchangeKeys.addIteration(fbBuilder, iteration);
|
||||
GetExchangeKeys.addTimestamp(fbBuilder, time);
|
||||
GetExchangeKeys.addSignature(fbBuilder, signed);
|
||||
getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder);
|
||||
} else {
|
||||
// start build
|
||||
GetExchangeKeys.startGetExchangeKeys(fbBuilder);
|
||||
GetExchangeKeys.addFlId(fbBuilder, id);
|
||||
GetExchangeKeys.addIteration(fbBuilder, iteration);
|
||||
GetExchangeKeys.addTimestamp(fbBuilder, time);
|
||||
getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder);
|
||||
}
|
||||
|
||||
fbBuilder.finish(getExchangeKeysRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
|
@ -535,6 +586,7 @@ public class CipherClient {
|
|||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
|
@ -558,22 +610,44 @@ public class CipherClient {
|
|||
clientPublicKeyList.clear();
|
||||
u1ClientList.clear();
|
||||
int length = bufData.remotePublickeysLength();
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer();
|
||||
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
|
||||
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
|
||||
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
|
||||
byte[] bufCpkList = byteBufferToList(bufCpk, sizeCpk);
|
||||
byte[] bufSpkList = byteBufferToList(bufSpk, sizeSpk);
|
||||
// copy bufCpkList and bufSpkList
|
||||
byte[] cPkByte = bufCpkList.clone();
|
||||
byte[] sPkByte = bufSpkList.clone();
|
||||
|
||||
// check signature
|
||||
boolean isPkiVerify = flParameter.isPkiVerify();
|
||||
if (isPkiVerify) {
|
||||
FLClientStatus checkResult = checkSignature(bufData, i, cPkByte, sPkByte);
|
||||
if (checkResult == FLClientStatus.FAILED) {
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
ClientPublicKey publicKey = new ClientPublicKey();
|
||||
String srcFlId = bufData.remotePublickeys(i).flId();
|
||||
publicKey.setFlID(srcFlId);
|
||||
ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer();
|
||||
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
|
||||
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
|
||||
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
|
||||
ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer();
|
||||
int sizePwIv = bufData.remotePublickeys(i).pwIvLength();
|
||||
ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer();
|
||||
int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength();
|
||||
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
|
||||
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
|
||||
publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv));
|
||||
publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt));
|
||||
NewArray<byte[]> bufCpkArray = new NewArray<>();
|
||||
bufCpkArray.setSize(sizeCpk);
|
||||
bufCpkArray.setArray(bufCpkList);
|
||||
NewArray<byte[]> bufSpkArray = new NewArray<>();
|
||||
bufSpkArray.setSize(sizeSpk);
|
||||
bufSpkArray.setArray(bufSpkList);
|
||||
publicKey.setCPK(bufCpkArray);
|
||||
publicKey.setSPK(bufSpkArray);
|
||||
clientPublicKeyList.put(srcFlId, publicKey);
|
||||
u1ClientList.add(srcFlId);
|
||||
}
|
||||
|
@ -598,6 +672,86 @@ public class CipherClient {
|
|||
}
|
||||
}
|
||||
|
||||
private FLClientStatus checkSignature(ReturnExchangeKeys bufData, int dataIndex, byte[] cPkByte, byte[] sPkByte) {
|
||||
ByteBuffer signature = bufData.remotePublickeys(dataIndex).signatureAsByteBuffer();
|
||||
byte[] sigByte = new byte[signature.remaining()];
|
||||
signature.get(sigByte);
|
||||
int certifyNum = bufData.remotePublickeys(dataIndex).certificateChainLength();
|
||||
String[] pemCerts = new String[certifyNum];
|
||||
for (int certIndex = 0; certIndex < certifyNum; certIndex++) {
|
||||
pemCerts[certIndex] = bufData.remotePublickeys(dataIndex).certificateChain(certIndex);
|
||||
}
|
||||
|
||||
X509Certificate[] x509Certificates = CertVerify.transformPemArrayToX509Array(pemCerts);
|
||||
if (x509Certificates.length < 2) {
|
||||
LOGGER.severe(Common.addTag("the length of x509Certificates is not valid, should be >= 2"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String certificateHash = PkiUtil.genHashFromCer(x509Certificates[1]);
|
||||
LOGGER.info(Common.addTag("Get certificate hash success!"));
|
||||
|
||||
// check srcId
|
||||
String srcFlId = bufData.remotePublickeys(dataIndex).flId();
|
||||
if (certificateHash.equals(srcFlId)) {
|
||||
LOGGER.info(Common.addTag("Check flID success and source flID is:" + srcFlId));
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("Check flID failed!" + "source flID: " + srcFlId + "Hash ID from certificate:" +
|
||||
" " + certificateHash.equals(srcFlId)));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
certificateList.put(srcFlId, x509Certificates);
|
||||
String timestamp = bufData.remotePublickeys(dataIndex).timestamp();
|
||||
String clientID = flParameter.getClientID();
|
||||
if (!verifySignature(clientID, x509Certificates, sigByte, cPkByte, sPkByte, timestamp, iteration)) {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] FlID: " + srcFlId +
|
||||
", signature authentication failed"));
|
||||
return FLClientStatus.FAILED;
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] Verify signature success!"));
|
||||
}
|
||||
|
||||
// check iteration and timestamp
|
||||
int remoteIter = bufData.iteration();
|
||||
FLClientStatus iterTimeCheck = checkIterAndTimestamp(remoteIter, timestamp);
|
||||
if (iterTimeCheck == FLClientStatus.FAILED) {
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus checkIterAndTimestamp(int remoteIter, String timestamp) {
|
||||
if (remoteIter != iteration) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] iteration check failed. Remote iteration of client: " + "is "
|
||||
+ remoteIter + ", which is not consistent with current iteration:" + iteration));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Date date = new Date();
|
||||
long currentTimeStamp = date.getTime();
|
||||
if (timestamp == null) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] Received timeStamp is null,please check it!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
long remoteTimeStamp = Long.parseLong(timestamp);
|
||||
long validIterInterval = flParameter.getValidInterval();
|
||||
if (currentTimeStamp - remoteTimeStamp > validIterInterval || currentTimeStamp < remoteTimeStamp) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] timeStamp check failed! The difference between" +
|
||||
" remote timestamp and current timestamp is beyond valid iteration interval!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private byte[] byteBufferToList(ByteBuffer buf, int size) {
|
||||
byte[] array = new byte[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
byte word = buf.get();
|
||||
array[i] = word;
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
private FLClientStatus requestShareSecrets() {
|
||||
FLClientStatus status = genIndividualSecret();
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
|
@ -642,25 +796,44 @@ public class CipherClient {
|
|||
}
|
||||
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
|
||||
|
||||
// start build
|
||||
RequestShareSecrets.startRequestShareSecrets(fbBuilder);
|
||||
RequestShareSecrets.addFlId(fbBuilder, id);
|
||||
RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs);
|
||||
RequestShareSecrets.addIteration(fbBuilder, iteration);
|
||||
RequestShareSecrets.addTimestamp(fbBuilder, time);
|
||||
int requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder);
|
||||
int requestShareSecretsRoot;
|
||||
byte[] signature = signTimeAndIter(dateTime, iteration);
|
||||
if (signature == null) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] get signature is null!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (signature.length > 0) {
|
||||
int signed = RequestShareSecrets.createSignatureVector(fbBuilder, signature);
|
||||
// start build
|
||||
RequestShareSecrets.startRequestShareSecrets(fbBuilder);
|
||||
RequestShareSecrets.addFlId(fbBuilder, id);
|
||||
RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs);
|
||||
RequestShareSecrets.addIteration(fbBuilder, iteration);
|
||||
RequestShareSecrets.addTimestamp(fbBuilder, time);
|
||||
RequestShareSecrets.addSignature(fbBuilder, signed);
|
||||
requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder);
|
||||
} else {
|
||||
// start build
|
||||
RequestShareSecrets.startRequestShareSecrets(fbBuilder);
|
||||
RequestShareSecrets.addFlId(fbBuilder, id);
|
||||
RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs);
|
||||
RequestShareSecrets.addIteration(fbBuilder, iteration);
|
||||
RequestShareSecrets.addTimestamp(fbBuilder, time);
|
||||
requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder);
|
||||
}
|
||||
|
||||
fbBuilder.finish(requestShareSecretsRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
|
||||
" " +
|
||||
"and request again"));
|
||||
" and request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
|
@ -708,11 +881,27 @@ public class CipherClient {
|
|||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
|
||||
GetShareSecrets.startGetShareSecrets(fbBuilder);
|
||||
GetShareSecrets.addFlId(fbBuilder, id);
|
||||
GetShareSecrets.addIteration(fbBuilder, iteration);
|
||||
GetShareSecrets.addTimestamp(fbBuilder, time);
|
||||
int getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder);
|
||||
int getShareSecrets;
|
||||
byte[] signature = signTimeAndIter(dateTime, iteration);
|
||||
if (signature == null) {
|
||||
LOGGER.severe(Common.addTag("[getShareSecrets] get signature is null!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (signature.length > 0) {
|
||||
int signed = GetShareSecrets.createSignatureVector(fbBuilder, signature);
|
||||
GetShareSecrets.startGetShareSecrets(fbBuilder);
|
||||
GetShareSecrets.addFlId(fbBuilder, id);
|
||||
GetShareSecrets.addIteration(fbBuilder, iteration);
|
||||
GetShareSecrets.addTimestamp(fbBuilder, time);
|
||||
GetShareSecrets.addSignature(fbBuilder, signed);
|
||||
getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder);
|
||||
} else {
|
||||
GetShareSecrets.startGetShareSecrets(fbBuilder);
|
||||
GetShareSecrets.addFlId(fbBuilder, id);
|
||||
GetShareSecrets.addIteration(fbBuilder, iteration);
|
||||
GetShareSecrets.addTimestamp(fbBuilder, time);
|
||||
getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder);
|
||||
}
|
||||
fbBuilder.finish(getShareSecrets);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
|
@ -724,6 +913,7 @@ public class CipherClient {
|
|||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
|
@ -864,6 +1054,19 @@ public class CipherClient {
|
|||
}
|
||||
retCode = clientListReq.getRetCode();
|
||||
|
||||
// clientListCheck
|
||||
if (flParameter.isPkiVerify()) {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] The mode is pkiVerify mode, start clientList check ..."));
|
||||
curStatus = clientListCheck();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = clientListCheck();
|
||||
}
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
return curStatus;
|
||||
}
|
||||
}
|
||||
|
||||
// SendReconstructSecret
|
||||
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
|
@ -876,4 +1079,340 @@ public class CipherClient {
|
|||
retCode = reconstructSecretReq.getRetCode();
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private byte[] signPkAndTime(String clientID, byte[] cPK, byte[] sPK, String time, int iterNum) {
|
||||
// concatenate cPK, sPK and time
|
||||
byte[] concatData = concatenateData(cPK, sPK, time, iterNum);
|
||||
LOGGER.info("concatenate data success!");
|
||||
// signature
|
||||
return SignAndVerify.signData(clientID, concatData);
|
||||
}
|
||||
|
||||
private static byte[] concatenateData(byte[] cPK, byte[] sPK, String time, int iterNum) {
|
||||
// concatenate cPK, sPK and time
|
||||
if (time == null) {
|
||||
LOGGER.severe(Common.addTag("[concatenateData] input time is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
byte[] byteTime = time.getBytes(StandardCharsets.UTF_8);
|
||||
String iterString = String.valueOf(iterNum);
|
||||
byte[] byteIter = iterString.getBytes(StandardCharsets.UTF_8);
|
||||
int concatLength = cPK.length + sPK.length + byteTime.length + byteIter.length;
|
||||
byte[] concatData = new byte[concatLength];
|
||||
|
||||
int offset = 0;
|
||||
System.arraycopy(cPK, 0, concatData, offset, cPK.length);
|
||||
|
||||
offset += cPK.length;
|
||||
System.arraycopy(sPK, 0, concatData, offset, sPK.length);
|
||||
|
||||
offset += sPK.length;
|
||||
System.arraycopy(byteTime, 0, concatData, offset, byteTime.length);
|
||||
|
||||
offset += byteTime.length;
|
||||
System.arraycopy(byteIter, 0, concatData, offset, byteIter.length);
|
||||
return concatData;
|
||||
}
|
||||
|
||||
private static byte[] concatenateIterAndTime(String time, int iterNum) {
|
||||
// concatenate cPK, sPK and time
|
||||
byte[] byteTime = time.getBytes(StandardCharsets.UTF_8);
|
||||
String iterString = String.valueOf(iterNum);
|
||||
byte[] byteIter = iterString.getBytes(StandardCharsets.UTF_8);
|
||||
int concatLength = byteTime.length + byteIter.length;
|
||||
byte[] concatData = new byte[concatLength];
|
||||
int offset = 0;
|
||||
System.arraycopy(byteTime, 0, concatData, offset, byteTime.length);
|
||||
offset += byteTime.length;
|
||||
System.arraycopy(byteIter, 0, concatData, offset, byteIter.length);
|
||||
return concatData;
|
||||
}
|
||||
|
||||
private static boolean verifySignature(String clientID, X509Certificate[] x509Certificates, byte[] signature,
|
||||
byte[] cPK, byte[] sPK, String timestamp, int iteration) {
|
||||
byte[] concatData = concatenateData(cPK, sPK, timestamp, iteration);
|
||||
return SignAndVerify.verifySignatureByCert(clientID, x509Certificates, concatData, signature);
|
||||
}
|
||||
|
||||
private FLClientStatus clientListCheck() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==================== ClientListCheck ======================"));
|
||||
FLClientStatus curStatus;
|
||||
// send signed clientList
|
||||
|
||||
curStatus = sendClientListSign();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = sendClientListSign();
|
||||
}
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
// get signed clientList
|
||||
|
||||
curStatus = getAllClientListSign();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = getAllClientListSign();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus sendClientListSign() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " +
|
||||
localFLParameter.getFlID() + "=============="));
|
||||
genDHKeyPairs();
|
||||
List<String> clientList = u3ClientList;
|
||||
int listSize = u3ClientList.size();
|
||||
if (listSize == 0) {
|
||||
LOGGER.severe("[Encrypt] u3List is empty, please check!");
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
// send signature
|
||||
byte[] clientListByte = transStringListToByte(clientList);
|
||||
byte[] listHash = SignAndVerify.getSHA256(clientListByte);
|
||||
String clientID = flParameter.getClientID();
|
||||
byte[] signature = SignAndVerify.signData(clientID, listHash);
|
||||
if (signature == null) {
|
||||
LOGGER.severe(Common.addTag("[sendClientListSign] the returned signature is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int signed = RequestExchangeKeys.createSignatureVector(fbBuilder, signature);
|
||||
|
||||
int sendClientListRoot;
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
byte[] reqSign = signTimeAndIter(dateTime, iteration);
|
||||
String flID = localFLParameter.getFlID();
|
||||
int id = fbBuilder.createString(flID);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
if (signature.length > 0) {
|
||||
int reqSigned = SendClientListSign.createSignatureVector(fbBuilder, reqSign);
|
||||
sendClientListRoot = SendClientListSign.createSendClientListSign(fbBuilder, id, iteration, time, signed,
|
||||
reqSigned);
|
||||
} else {
|
||||
sendClientListRoot = SendClientListSign.createSendClientListSign(fbBuilder, id, iteration, time, signed, 0);
|
||||
}
|
||||
|
||||
fbBuilder.finish(sendClientListRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/pushListSign", msg);
|
||||
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[sendClientListSign] the server is not ready now, need wait some time and " +
|
||||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ResponseClientListSign responseClientListSign =
|
||||
ResponseClientListSign.getRootAsResponseClientListSign(buffer);
|
||||
return judgeRequestClientList(responseClientListSign);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
private FLClientStatus judgeRequestClientList(ResponseClientListSign bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestClientListSign**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestClientListSign success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestClientListSign out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestClientListSign"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in RequestClientListSign" +
|
||||
" is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
private FLClientStatus getAllClientListSign() {
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
|
||||
int requestAllClientListSign;
|
||||
byte[] signature = signTimeAndIter(dateTime, iteration);
|
||||
if (signature.length > 0) {
|
||||
int signed = RequestAllClientListSign.createSignatureVector(fbBuilder, signature);
|
||||
requestAllClientListSign = RequestAllClientListSign.createRequestAllClientListSign(fbBuilder, id,
|
||||
iteration, time, signed);
|
||||
} else {
|
||||
requestAllClientListSign = RequestAllClientListSign.createRequestAllClientListSign(fbBuilder, id,
|
||||
iteration, time, 0);
|
||||
}
|
||||
|
||||
fbBuilder.finish(requestAllClientListSign);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getListSign", msg);
|
||||
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[getAllClientListSign] the server is not ready now, need wait some time " +
|
||||
"and request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ReturnAllClientListSign returnAllClientList =
|
||||
ReturnAllClientListSign.getRootAsReturnAllClientListSign(buffer);
|
||||
return judgeAllClientList(returnAllClientList);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
private FLClientStatus judgeAllClientList(ReturnAllClientListSign bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetAllClientsList**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetAllClientList success"));
|
||||
int length = bufData.clientListSignLength();
|
||||
String clientID = flParameter.getClientID();
|
||||
String localFlID = localFLParameter.getFlID();
|
||||
byte[] localClientList = transStringListToByte(u3ClientList);
|
||||
byte[] localListHash = SignAndVerify.getSHA256(localClientList);
|
||||
for (int i = 0; i < length; i++) {
|
||||
// verify signature
|
||||
ByteBuffer signature = bufData.clientListSign(i).signatureAsByteBuffer();
|
||||
byte[] sigByte = new byte[signature.remaining()];
|
||||
signature.get(sigByte);
|
||||
if (bufData.clientListSign(i).flId() == null) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] get flID failed!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String srcFlId = bufData.clientListSign(i).flId();
|
||||
X509Certificate[] remoteCertificates = certificateList.get(srcFlId);
|
||||
if (localFlID.equals(srcFlId)) {
|
||||
continue;
|
||||
} // Do not verify itself
|
||||
if (!SignAndVerify.verifySignatureByCert(clientID, remoteCertificates, localListHash, sigByte)) {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] FlID: " + srcFlId +
|
||||
", signature authentication failed"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||
"GetAllClientsList again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetAllClientsList out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetAllClientsList"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnAllClientList " +
|
||||
"is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] transStringListToByte(List<String> stringList) {
|
||||
int byteNum = 0;
|
||||
for (String value : stringList) {
|
||||
byte[] stringByte = value.getBytes(StandardCharsets.UTF_8);
|
||||
byteNum += stringByte.length;
|
||||
}
|
||||
byte[] concatData = new byte[byteNum];
|
||||
int offset = 0;
|
||||
for (String str : stringList) {
|
||||
byte[] stringByte = str.getBytes(StandardCharsets.UTF_8);
|
||||
System.arraycopy(stringByte, 0, concatData, offset, stringByte.length);
|
||||
offset += stringByte.length;
|
||||
}
|
||||
return concatData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add signature on timestamp and iteration
|
||||
*
|
||||
* @param dateTime the timestamp of data
|
||||
* @param iteration iteration number
|
||||
* @return signed time and iteration
|
||||
*/
|
||||
public static byte[] signTimeAndIter(String dateTime, int iteration) {
|
||||
// signature
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
String clientID = flParameter.getClientID();
|
||||
boolean isPkiVerify = flParameter.isPkiVerify();
|
||||
byte[] signature = new byte[0];
|
||||
if (isPkiVerify) {
|
||||
LOGGER.info(Common.addTag("ClientID is:" + clientID));
|
||||
byte[] concatData = concatenateIterAndTime(dateTime, iteration);
|
||||
signature = SignAndVerify.signData(clientID, concatData);
|
||||
}
|
||||
return signature;
|
||||
}
|
||||
|
||||
private static String transformX509ToPem(X509Certificate x509Certificate) {
|
||||
if (x509Certificate == null) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] x509Certificate is null, please check!"));
|
||||
return null;
|
||||
}
|
||||
String pemCert;
|
||||
try {
|
||||
byte[] derCert = x509Certificate.getEncoded();
|
||||
pemCert = new String(Base64.getEncoder().encode(derCert));
|
||||
} catch (CertificateEncodingException e) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
|
||||
return null;
|
||||
}
|
||||
return pemCert;
|
||||
}
|
||||
|
||||
private static String[] transformX509ArrayToPemArray(X509Certificate[] x509Certificates) {
|
||||
if (x509Certificates == null || x509Certificates.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int nSize = x509Certificates.length;
|
||||
String[] pemCerts = new String[nSize];
|
||||
for (int i = 0; i < nSize; ++i) {
|
||||
String pemCert = transformX509ToPem(x509Certificates[i]);
|
||||
pemCerts[i] = pemCert;
|
||||
}
|
||||
return pemCerts;
|
||||
}
|
||||
}
|
|
@ -16,6 +16,14 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import org.bouncycastle.crypto.BlockCipher;
|
||||
import org.bouncycastle.crypto.engines.AESEngine;
|
||||
import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder;
|
||||
|
@ -33,6 +41,9 @@ import java.util.logging.Logger;
|
|||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Define basic global methods used in federated learning task.
|
||||
*
|
||||
|
@ -44,6 +55,9 @@ public class Common {
|
|||
*/
|
||||
public static final String LOG_TITLE = "<FLClient> ";
|
||||
|
||||
public static final String LOG_DEPRECATED = "This method will be deprecated in the next version, it is " +
|
||||
"recommended to " +
|
||||
"use the latest method according to the use cases in the official website tutorial";
|
||||
/**
|
||||
* The list of trust flName.
|
||||
*/
|
||||
|
@ -63,8 +77,16 @@ public class Common {
|
|||
* The tag when server is not ready.
|
||||
*/
|
||||
public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished.";
|
||||
|
||||
/**
|
||||
* use to stop job.
|
||||
*/
|
||||
private static final Object STOP_OBJECT = new Object();
|
||||
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
||||
private static List<String> envTrustList = new ArrayList<>(Arrays.asList("x86", "android"));
|
||||
private static SecureRandom secureRandom;
|
||||
private static int iteration;
|
||||
private static boolean isHttps;
|
||||
|
||||
/**
|
||||
* Generate the URL for device-sever interaction
|
||||
|
@ -107,6 +129,7 @@ public class Common {
|
|||
throw new IllegalArgumentException();
|
||||
}
|
||||
String tag = domainName.split("//")[0] + "//";
|
||||
setIsHttps(domainName.split("//")[0].split(":")[0]);
|
||||
Random rand = new Random();
|
||||
int randomNum = rand.nextInt(100000) % serverNum + port;
|
||||
url = tag + ip + ":" + String.valueOf(randomNum);
|
||||
|
@ -167,6 +190,17 @@ public class Common {
|
|||
return (FL_NAME_TRUST_LIST.contains(flName));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the deploy environment set by user is in the trust list.
|
||||
*
|
||||
* @param env the deploy environment for federated learning task set by user.
|
||||
* @return boolean value, true indicates the deploy environment set by user is valid, false indicates the deploy
|
||||
* environment set by user is not valid.
|
||||
*/
|
||||
public static boolean checkEnv(String env) {
|
||||
return (envTrustList.contains(env));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the sslProtocol set by user is in the trust list.
|
||||
*
|
||||
|
@ -184,11 +218,23 @@ public class Common {
|
|||
* @param millis the waiting time (ms).
|
||||
*/
|
||||
public static void sleep(long millis) {
|
||||
try {
|
||||
Thread.sleep(millis); // 1000 milliseconds is one second.
|
||||
} catch (InterruptedException ex) {
|
||||
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
|
||||
Thread.currentThread().interrupt();
|
||||
if (millis > 0) {
|
||||
try {
|
||||
synchronized (STOP_OBJECT) {
|
||||
STOP_OBJECT.wait(millis); // 1000 milliseconds is one second.
|
||||
}
|
||||
} catch (InterruptedException ex) {
|
||||
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use to stop the wait method of a Object.
|
||||
*/
|
||||
public static void notifyObject() {
|
||||
synchronized (STOP_OBJECT) {
|
||||
STOP_OBJECT.notify();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -261,10 +307,20 @@ public class Common {
|
|||
String messageStr = new String(message);
|
||||
if (messageStr.contains(SAFE_MOD)) {
|
||||
LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again"));
|
||||
if (messageStr.split(":").length == 2) {
|
||||
iteration = Integer.parseInt(messageStr.split(":")[1]);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
|
||||
}
|
||||
return false;
|
||||
} else if (messageStr.contains(JOB_NOT_AVAILABLE)) {
|
||||
LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " +
|
||||
"again"));
|
||||
if (messageStr.split(":").length == 2) {
|
||||
iteration = Integer.parseInt(messageStr.split(":")[1]);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
|
@ -307,7 +363,7 @@ public class Common {
|
|||
* Check whether the path set by user exists.
|
||||
*
|
||||
* @param path the path set by user.
|
||||
* @return boolean value, true indicates the path is exist, false indicates the path is not exist
|
||||
* @return boolean value, true indicates the path is exist, false indicates the path does not exist
|
||||
*/
|
||||
public static boolean checkPath(String path) {
|
||||
if (path == null) {
|
||||
|
@ -323,7 +379,7 @@ public class Common {
|
|||
LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
|
||||
File file = new File(paths[i]);
|
||||
if (!file.exists()) {
|
||||
LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check"));
|
||||
LOGGER.severe(Common.addTag("[checkPath] the path does not exist, please check"));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -415,4 +471,131 @@ public class Common {
|
|||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Record the current iteration when server is not ready.
|
||||
*
|
||||
* @return int value, the current iteration when server is not ready.
|
||||
*/
|
||||
public static int getIteration() {
|
||||
return iteration;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine whether to conduct https communication according to the domain name set by user.
|
||||
*
|
||||
* @return boolean value, true means conducting https communication, false means conducting http communication.
|
||||
*/
|
||||
public static boolean isHttps() {
|
||||
return isHttps;
|
||||
}
|
||||
|
||||
public static void setIsHttps(String tag) {
|
||||
if ("https".equals(tag)) {
|
||||
LOGGER.info(Common.addTag("conducting https communication"));
|
||||
Common.isHttps = true;
|
||||
} else if ("http".equals(tag)) {
|
||||
LOGGER.info(Common.addTag("conducting http communication"));
|
||||
Common.isHttps = false;
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("The domain header set by the user is incorrect, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Initialization session.
|
||||
*
|
||||
* @return the status code in client.
|
||||
*/
|
||||
public static FLClientStatus initSession(String modelPath) {
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
return deprecatedInitSession();
|
||||
}
|
||||
Status tag;
|
||||
LOGGER.info(Common.addTag("==========Loading model, " + modelPath + " Create " +
|
||||
" Session============="));
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig());
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
"is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free session.
|
||||
*/
|
||||
protected static void freeSession() {
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
deprecatedFreeSession();
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
client.free();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialization session.
|
||||
*
|
||||
* @return the status code in client.
|
||||
*/
|
||||
private static FLClientStatus deprecatedInitSession() {
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
int tag = 0;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||
"Train Session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
"is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
|
||||
"Create inference Session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||
"Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
|
||||
"-1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free session.
|
||||
*/
|
||||
private static void deprecatedFreeSession() {
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
SessionUtil.free(alTrainBert.getTrainSession());
|
||||
if (!flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.TIME_OUT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ANDROID;
|
||||
import static com.mindspore.flclient.LocalFLParameter.X86;
|
||||
|
||||
import okhttp3.Call;
|
||||
import okhttp3.Callback;
|
||||
|
@ -49,12 +51,16 @@ import javax.net.ssl.X509TrustManager;
|
|||
*/
|
||||
public class FLCommunication implements IFLCommunication {
|
||||
private static int timeOut;
|
||||
private static boolean ifCertificateVerify = false;
|
||||
private static String sslProtocol;
|
||||
private static String env;
|
||||
private static SSLSocketFactory sslSocketFactory;
|
||||
private static X509TrustManager x509TrustManager;
|
||||
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
|
||||
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
|
||||
private static volatile FLCommunication communication;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private OkHttpClient client;
|
||||
|
||||
private FLCommunication() {
|
||||
|
@ -63,7 +69,14 @@ public class FLCommunication implements IFLCommunication {
|
|||
} else {
|
||||
timeOut = TIME_OUT;
|
||||
}
|
||||
ifCertificateVerify = flParameter.isUseSSL();
|
||||
sslProtocol = flParameter.getSslProtocol();
|
||||
if (!Common.checkFLName(flParameter.getFlName())) {
|
||||
env = flParameter.getDeployEnv();
|
||||
if (ANDROID.equals(env)) {
|
||||
sslSocketFactory = flParameter.getSslSocketFactory();
|
||||
x509TrustManager = flParameter.getX509TrustManager();
|
||||
}
|
||||
}
|
||||
client = getOkHttpClient();
|
||||
}
|
||||
|
||||
|
@ -85,33 +98,29 @@ public class FLCommunication implements IFLCommunication {
|
|||
}
|
||||
};
|
||||
final TrustManager[] trustAllCerts = new TrustManager[]{trustManager};
|
||||
try {
|
||||
LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut));
|
||||
OkHttpClient.Builder builder = new OkHttpClient.Builder();
|
||||
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
|
||||
if (ifCertificateVerify) {
|
||||
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
|
||||
SSLSocketFactoryTools.getInstance().getmTrustManager());
|
||||
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
|
||||
} else {
|
||||
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||
sslContext.init(null, trustAllCerts, Common.getSecureRandom());
|
||||
final SSLSocketFactory sslFactory = sslContext.getSocketFactory();
|
||||
builder.sslSocketFactory(sslFactory, trustManager);
|
||||
LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut));
|
||||
OkHttpClient.Builder builder = new OkHttpClient.Builder();
|
||||
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
|
||||
if (Common.isHttps()) {
|
||||
if (ANDROID.equals(env)) {
|
||||
builder.sslSocketFactory(sslSocketFactory, x509TrustManager);
|
||||
builder.hostnameVerifier(new HostnameVerifier() {
|
||||
@Override
|
||||
public boolean verify(String arg0, SSLSession arg1) {
|
||||
return true;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
|
||||
SSLSocketFactoryTools.getInstance().getmTrustManager());
|
||||
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
|
||||
}
|
||||
return builder.build();
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
|
||||
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage()));
|
||||
throw new IllegalArgumentException(ex);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("conducting http communication, do not need SSLSocketFactoryTools"));
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -33,6 +33,7 @@ public class FLJobResultCallback implements IFLJobResultCallback {
|
|||
* @param iterationSeq Iteration number
|
||||
* @param resultCode Status Code
|
||||
*/
|
||||
@Override
|
||||
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
|
||||
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " +
|
||||
iterationSeq + " resultCode: " + resultCode));
|
||||
|
@ -45,6 +46,7 @@ public class FLJobResultCallback implements IFLJobResultCallback {
|
|||
* @param iterationCount total Iteration numbers
|
||||
* @param resultCode Status Code
|
||||
*/
|
||||
@Override
|
||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
|
||||
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " +
|
||||
iterationCount + " resultCode: " + resultCode));
|
||||
|
|
|
@ -22,8 +22,16 @@ import static com.mindspore.flclient.LocalFLParameter.LENET;
|
|||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.CommonUtils;
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import com.mindspore.flclient.pki.PkiBean;
|
||||
import com.mindspore.flclient.pki.PkiUtil;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
|
||||
import mindspore.schema.CipherPublicParams;
|
||||
import mindspore.schema.FLPlan;
|
||||
|
@ -36,6 +44,7 @@ import java.io.IOException;
|
|||
import java.nio.ByteBuffer;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.logging.Logger;
|
||||
|
@ -54,7 +63,7 @@ public class FLLiteClient {
|
|||
private double dpNormClipAdapt = 0.05d;
|
||||
private FLCommunication flCommunication;
|
||||
private FLClientStatus status;
|
||||
private int retCode;
|
||||
private int retCode = ResponseCode.RequestError;
|
||||
private int iterations = 1;
|
||||
private int epochs = 1;
|
||||
private int batchSize = 16;
|
||||
|
@ -68,12 +77,15 @@ public class FLLiteClient {
|
|||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private SecureProtocol secureProtocol = new SecureProtocol();
|
||||
private String nextRequestTime;
|
||||
private Client client;
|
||||
private Map<String, float[]> oldFeatureMap;
|
||||
|
||||
/**
|
||||
* Defining a constructor of teh class FLLiteClient.
|
||||
*/
|
||||
public FLLiteClient() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
client = ClientManager.getClient(flParameter.getFlName());
|
||||
}
|
||||
|
||||
private int setGlobalParameters(ResponseFLJob flJob) {
|
||||
|
@ -87,22 +99,15 @@ public class FLLiteClient {
|
|||
batchSize = flPlan.miniBatch();
|
||||
String serverMod = flPlan.serverMode();
|
||||
localFLParameter.setServerMod(serverMod);
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <serverMod> from server: " + serverMod));
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AlTrainBert: " + batchSize));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
alTrainBert.setBatchSize(batchSize);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
trainLenet.setBatchSize(batchSize);
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
deprecatedSetBatchSize(batchSize);
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
|
||||
return -1;
|
||||
LOGGER.info(Common.addTag("[startFLJob] not set <batchSize> for client: " + batchSize));
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
|
||||
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <serverMod> from server: " + serverMod));
|
||||
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <iterations> from server: " + iterations));
|
||||
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <epochs> from server: " + epochs));
|
||||
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <batchSize> from server: " + batchSize));
|
||||
CipherPublicParams cipherPublicParams = flPlan.cipher();
|
||||
if (cipherPublicParams == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null"));
|
||||
|
@ -232,7 +237,12 @@ public class FLLiteClient {
|
|||
StartFLJob startFLJob = StartFLJob.getInstance();
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time);
|
||||
|
||||
PkiBean pkiBean = null;
|
||||
if (flParameter.isPkiVerify()) {
|
||||
pkiBean = PkiUtil.genPkiBean(flParameter.getClientID(), time);
|
||||
}
|
||||
byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time, pkiBean);
|
||||
try {
|
||||
long start = Common.startTime("single startFLJob");
|
||||
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
|
||||
|
@ -243,6 +253,7 @@ public class FLLiteClient {
|
|||
status = FLClientStatus.RESTART;
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] the response message length: " + message.length));
|
||||
|
@ -250,12 +261,9 @@ public class FLLiteClient {
|
|||
ByteBuffer buffer = ByteBuffer.wrap(message);
|
||||
ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer);
|
||||
status = judgeStartFLJob(startFLJob, responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " +
|
||||
e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
failed("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage(),
|
||||
ResponseCode.RequestError);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
@ -263,12 +271,13 @@ public class FLLiteClient {
|
|||
private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
|
||||
iteration = responseDataBuf.iteration();
|
||||
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
|
||||
retCode = startFLJob.getRetCode();
|
||||
status = response;
|
||||
switch (response) {
|
||||
case SUCCESS:
|
||||
LOGGER.info(Common.addTag("[startFLJob] startFLJob success"));
|
||||
featureSize = startFLJob.getFeatureSize();
|
||||
secureProtocol.setEncryptFeatureName(startFLJob.getEncryptFeatureName());
|
||||
secureProtocol.setUpdateFeatureName(startFLJob.getUpdateFeatureName());
|
||||
LOGGER.info(Common.addTag("[startFLJob] ***the feature size get in ResponseFLJob***: " + featureSize));
|
||||
int tag = setGlobalParameters(responseDataBuf);
|
||||
if (tag == -1) {
|
||||
|
@ -297,38 +306,37 @@ public class FLLiteClient {
|
|||
return status;
|
||||
}
|
||||
|
||||
private FLClientStatus trainLoop() {
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("[train] train in " + flParameter.getFlName()));
|
||||
Status tag = client.trainModel(epochs);
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
failed("[train] unsolved error code in <client.trainModel>", ResponseCode.RequestError);
|
||||
}
|
||||
client.saveModel(flParameter.getTrainModelPath());
|
||||
Common.freeSession();
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Define the training process.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus localTrain() {
|
||||
|
||||
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration +
|
||||
"===================================="));
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[train] train in albert"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[train] unsolved error code in <alTrainBert.trainModel>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[train] train in lenet"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[train] unsolved error code in <trainLenet.trainModel>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
status = deprecatedTrainLoop();
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[train] the flName is not valid"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
status = trainLoop();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
@ -357,6 +365,7 @@ public class FLLiteClient {
|
|||
status = FLClientStatus.RESTART;
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[updateModel] the response message length: " + message.length));
|
||||
|
@ -364,16 +373,14 @@ public class FLLiteClient {
|
|||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer);
|
||||
status = updateModelBuf.doResponse(responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
retCode = updateModelBuf.getRetCode();
|
||||
if (status == FLClientStatus.RESTART) {
|
||||
nextRequestTime = responseDataBuf.nextReqTime();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " +
|
||||
e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
failed("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage(),
|
||||
ResponseCode.RequestError);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
@ -396,6 +403,7 @@ public class FLLiteClient {
|
|||
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
|
||||
"again"));
|
||||
status = FLClientStatus.WAIT;
|
||||
retCode = ResponseCode.SucNotReady;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] the response message length: " + message.length));
|
||||
|
@ -404,19 +412,35 @@ public class FLLiteClient {
|
|||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
||||
status = getModelBuf.doResponse(responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
retCode = getModelBuf.getRetCode();
|
||||
if (status == FLClientStatus.RESTART) {
|
||||
nextRequestTime = responseDataBuf.timestamp();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] get response from server ok!"));
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[getModel] un sloved error code: catch IOException: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
failed("[getModel] un sloved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
private Map<String, float[]> getFeatureMap() {
|
||||
Map<String, float[]> featureMap = new HashMap<>();
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
featureMap = deprecatedGetFeatureMap();
|
||||
return featureMap;
|
||||
}
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
Common.freeSession();
|
||||
retCode = ResponseCode.RequestError;
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<MSTensor> features = client.getFeatures();
|
||||
featureMap = CommonUtils.convertTensorToFeatures(features);
|
||||
Common.freeSession();
|
||||
return featureMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the weight of the model before training.
|
||||
*
|
||||
|
@ -456,6 +480,58 @@ public class FLLiteClient {
|
|||
return mapBeforeTrain;
|
||||
}
|
||||
|
||||
private void getOldFeatureMap() {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
Map<String, float[]> featureMap = getFeatureMap();
|
||||
oldFeatureMap = getOldMapCopy(featureMap);
|
||||
}
|
||||
}
|
||||
|
||||
public void updateDpNormClip() {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
Map<String, float[]> fedFeatureMap = getFeatureMap();
|
||||
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
|
||||
if (fedWeightUpdateNorm == -1) {
|
||||
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
|
||||
"-1, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
|
||||
float newNormCLip = (float) getDpNormClipFactor() * fedWeightUpdateNorm;
|
||||
if (iteration == 1) {
|
||||
setDpNormClipAdapt(newNormCLip);
|
||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||
} else {
|
||||
if (newNormCLip < getDpNormClipAdapt()) {
|
||||
setDpNormClipAdapt(newNormCLip);
|
||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||
}
|
||||
}
|
||||
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + getDpNormClipAdapt()));
|
||||
}
|
||||
}
|
||||
|
||||
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
|
||||
float updateL2Norm = 0f;
|
||||
for (String key : originalData.keySet()) {
|
||||
float[] data = originalData.get(key);
|
||||
float[] dataAfterUpdate = newData.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
if (j >= dataAfterUpdate.length) {
|
||||
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
|
||||
"please check");
|
||||
return -1;
|
||||
}
|
||||
float updateData = data[j] - dataAfterUpdate[j];
|
||||
updateL2Norm += updateData * updateData;
|
||||
}
|
||||
}
|
||||
updateL2Norm = (float) Math.sqrt(updateL2Norm);
|
||||
return updateL2Norm;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain pairwise mask and individual mask.
|
||||
*
|
||||
|
@ -477,16 +553,14 @@ public class FLLiteClient {
|
|||
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
||||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
// get the feature map before train
|
||||
getOldFeatureMap();
|
||||
if (oldFeatureMap.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the return map in getOldFeatureMapis empty "));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Map<String, float[]> copyMap = getOldMapCopy(map);
|
||||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, copyMap);
|
||||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
|
||||
|
@ -537,85 +611,50 @@ public class FLLiteClient {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus evaluateLoop() {
|
||||
client.free();
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
|
||||
float acc = 0;
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath()));
|
||||
acc = client.evalModel();
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig());
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getTrainModelPath()));
|
||||
acc = client.evalModel();
|
||||
}
|
||||
if (Float.isNaN(acc)) {
|
||||
failed("[evaluate] unsolved error code in <evalModel>: the return acc is NAN", ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
return status;
|
||||
}
|
||||
|
||||
private void failed(String log, int retCode) {
|
||||
LOGGER.severe(Common.addTag(log));
|
||||
status = FLClientStatus.FAILED;
|
||||
this.retCode = retCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate model after getting model from server.
|
||||
*
|
||||
* @return the status code in client.
|
||||
*/
|
||||
public FLClientStatus evaluateModel() {
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("===================================evaluate model after getting model from " +
|
||||
"server==================================="));
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
float acc = 0;
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||
flParameter.getIdsFile(), true);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the " +
|
||||
"return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
acc = alInferBert.evalModel();
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||
flParameter.getIdsFile());
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the " +
|
||||
"return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
acc = alTrainBert.evalModel();
|
||||
}
|
||||
if (Float.isNaN(acc)) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
|
||||
" idsFile: " + flParameter.getIdsFile()));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
if (flParameter.getTestDataset().split(",").length < 2) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " +
|
||||
"format of <data.bin,label.bin> "));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
|
||||
flParameter.getTestDataset().split(",")[1]);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return " +
|
||||
"dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
float acc = trainLenet.evalModel();
|
||||
if (Float.isNaN(acc)) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc" +
|
||||
" is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
|
||||
flParameter.getTestDataset().split(",")[1]));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
status = deprecatedEvaluateLoop();
|
||||
} else {
|
||||
status = evaluateLoop();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
@ -623,10 +662,43 @@ public class FLLiteClient {
|
|||
/**
|
||||
* Set date path.
|
||||
*
|
||||
* @param dataPath, train or test dataset and label set.
|
||||
* @return date size.
|
||||
*/
|
||||
public int setInput(String dataPath) {
|
||||
public int setInput() {
|
||||
int dataSize = 0;
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
dataSize = deprecatedSetInput(flParameter.getTrainDataset());
|
||||
return dataSize;
|
||||
}
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("==========set input==========="));
|
||||
|
||||
// train
|
||||
dataSize = client.initDataSets(flParameter.getDataMap()).get(RunType.TRAINMODE);
|
||||
if (dataSize <= 0) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return -1;
|
||||
}
|
||||
return dataSize;
|
||||
}
|
||||
|
||||
private int deprecatedSetBatchSize(int batchSize) {
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AlTrainBert: " + batchSize));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
alTrainBert.setBatchSize(batchSize);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
trainLenet.setBatchSize(batchSize);
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private int deprecatedSetInput(String dataPath) {
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("==========set input==========="));
|
||||
int dataSize = 0;
|
||||
|
@ -653,61 +725,139 @@ public class FLLiteClient {
|
|||
return dataSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialization session.
|
||||
*
|
||||
* @return the status code in client.
|
||||
*/
|
||||
public FLClientStatus initSession() {
|
||||
int tag = 0;
|
||||
private FLClientStatus deprecatedTrainLoop() {
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||
"Train Session============="));
|
||||
LOGGER.info(Common.addTag("[train] train in albert"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
"is -1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
failed("[train] unsolved error code in <alTrainBert.trainModel>", ResponseCode.RequestError);
|
||||
}
|
||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
|
||||
"Create inference Session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||
"Train Session============="));
|
||||
LOGGER.info(Common.addTag("[train] train in lenet"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
failed("[train] unsolved error code in <trainLenet.trainModel>", ResponseCode.RequestError);
|
||||
}
|
||||
} else {
|
||||
failed("[train] the flName is not valid", ResponseCode.RequestError);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
|
||||
"-1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
Common.freeSession();
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Free session.
|
||||
*/
|
||||
protected void freeSession() {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
SessionUtil.free(alTrainBert.getTrainSession());
|
||||
if (!flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
private Map<String, float[]> deprecatedGetFeatureMap() {
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
Common.freeSession();
|
||||
retCode = ResponseCode.RequestError;
|
||||
return new HashMap<>();
|
||||
}
|
||||
Map<String, float[]> featureMap = new HashMap<>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
Common.freeSession();
|
||||
return featureMap;
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus deprecatedEvaluateLoop() {
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
float acc = 0;
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
int tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
if (tag == -1) {
|
||||
failed("[evaluate] unsolved error code in <initSessionAndInputs>: the return is -1",
|
||||
ResponseCode.RequestError);
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||
flParameter.getIdsFile(), true);
|
||||
if (dataSize <= 0) {
|
||||
failed("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0",
|
||||
ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
acc = alInferBert.evalModel();
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
int tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), false);
|
||||
if (tag == -1) {
|
||||
failed("[evaluate] unsolved error code in <initSessionAndInputs>: the return is -1",
|
||||
ResponseCode.RequestError);
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||
flParameter.getIdsFile());
|
||||
if (dataSize <= 0) {
|
||||
failed("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0",
|
||||
ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
acc = alTrainBert.evalModel();
|
||||
SessionUtil.free(alTrainBert.getTrainSession());
|
||||
}
|
||||
if (Float.isNaN(acc)) {
|
||||
failed("[evaluate] unsolved error code in <evalModel>: the return acc is NAN",
|
||||
ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
|
||||
" idsFile: " + flParameter.getIdsFile()));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
if (flParameter.getTestDataset().split(",").length < 2) {
|
||||
failed("[evaluate] the set testDataPath for lenet is not valid, should be the format of <data.bin," +
|
||||
"label.bin>", ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
int tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
failed("[evaluate] unsolved error code in <initSessionAndInputs>: the return is -1",
|
||||
ResponseCode.RequestError);
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
|
||||
flParameter.getTestDataset().split(",")[1]);
|
||||
if (dataSize <= 0) {
|
||||
failed("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0",
|
||||
ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
float acc = trainLenet.evalModel();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
if (Float.isNaN(acc)) {
|
||||
failed("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN",
|
||||
ResponseCode.RequestError);
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
|
||||
flParameter.getTestDataset().split(",")[1]));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
}
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,10 +18,19 @@ package com.mindspore.flclient;
|
|||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
|
||||
/**
|
||||
* Defines global parameters used during federated learning and these parameters are provided for users to set.
|
||||
*
|
||||
|
@ -41,21 +50,39 @@ public class FLParameter {
|
|||
public static final int SLEEP_TIME = 1000;
|
||||
private static volatile FLParameter flParameter;
|
||||
|
||||
private String deployEnv;
|
||||
private String domainName;
|
||||
private String certPath;
|
||||
|
||||
private SSLSocketFactory sslSocketFactory;
|
||||
private X509TrustManager x509TrustManager;
|
||||
private IFLJobResultCallback iflJobResultCallback = new FLJobResultCallback();
|
||||
|
||||
private String trainDataset;
|
||||
private String vocabFile = "null";
|
||||
private String idsFile = "null";
|
||||
private String testDataset = "null";
|
||||
private boolean useSSL = false;
|
||||
private String flName;
|
||||
private String trainModelPath;
|
||||
private String inferModelPath;
|
||||
private String sslProtocol = "TLSv1.2";
|
||||
private String clientID;
|
||||
private boolean useSSL = false;
|
||||
private int timeOut;
|
||||
private int sleepTime;
|
||||
private boolean ifUseElb = false;
|
||||
private int serverNum = 1;
|
||||
private boolean ifPkiVerify = false;
|
||||
private String equipCrlPath = "null";
|
||||
private long validIterInterval = 3600000L;
|
||||
private int threadNum = 1;
|
||||
private BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
|
||||
|
||||
private List<String> trainWeightName = new ArrayList<>();
|
||||
private List<String> inferWeightName = new ArrayList<>();
|
||||
private Map<RunType, List<String>> dataMap = new HashMap<>();
|
||||
private ServerMod serverMod;
|
||||
private int batchSize;
|
||||
|
||||
private FLParameter() {
|
||||
clientID = UUID.randomUUID().toString();
|
||||
|
@ -79,10 +106,28 @@ public class FLParameter {
|
|||
return localRef;
|
||||
}
|
||||
|
||||
public String getDeployEnv() {
|
||||
if (deployEnv == null || deployEnv.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <env> is null, please set it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return deployEnv;
|
||||
}
|
||||
|
||||
public void setDeployEnv(String env) {
|
||||
if (Common.checkEnv(env)) {
|
||||
this.deployEnv = env;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <env> is not in envTrustList: x86, android, " +
|
||||
"please check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getDomainName() {
|
||||
if (domainName == null || domainName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is null or empty, please set it " +
|
||||
"before use"));
|
||||
"before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return domainName;
|
||||
|
@ -91,16 +136,33 @@ public class FLParameter {
|
|||
public void setDomainName(String domainName) {
|
||||
if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is not valid, it should be like " +
|
||||
"as https://...... or http://......, please check it before set"));
|
||||
"as https://...... or http://......, please check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.domainName = domainName;
|
||||
}
|
||||
|
||||
public String getClientID() {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return clientID;
|
||||
}
|
||||
|
||||
public void setClientID(String clientID) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check " +
|
||||
"before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.clientID = clientID;
|
||||
}
|
||||
|
||||
public String getCertPath() {
|
||||
if (certPath == null || certPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, please set it " +
|
||||
"before use"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, the <certPath> " +
|
||||
"must be set when conducting https communication, please set it by FLParameter.setCertPath()"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return certPath;
|
||||
|
@ -111,28 +173,70 @@ public class FLParameter {
|
|||
if (Common.checkPath(realCertPath)) {
|
||||
this.certPath = realCertPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it " +
|
||||
"before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> does not exist, it must be a valid" +
|
||||
" path when conducting https communication, please check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public SSLSocketFactory getSslSocketFactory() {
|
||||
if (sslSocketFactory == null) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <sslSocketFactory> is null, the " +
|
||||
"<sslSocketFactory> must be set when the deployEnv being \"android\", please set it by " +
|
||||
"FLParameter.setSslSocketFactory()"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return sslSocketFactory;
|
||||
}
|
||||
|
||||
public void setSslSocketFactory(SSLSocketFactory sslSocketFactory) {
|
||||
this.sslSocketFactory = sslSocketFactory;
|
||||
}
|
||||
|
||||
public X509TrustManager getX509TrustManager() {
|
||||
if (x509TrustManager == null) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <x509TrustManager> is null, the " +
|
||||
"<x509TrustManager> must be set when the deployEnv being \"android\", please set it by " +
|
||||
"FLParameter.setX509TrustManager()"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return x509TrustManager;
|
||||
}
|
||||
|
||||
public void setX509TrustManager(X509TrustManager x509TrustManager) {
|
||||
this.x509TrustManager = x509TrustManager;
|
||||
}
|
||||
|
||||
public IFLJobResultCallback getIflJobResultCallback() {
|
||||
if (iflJobResultCallback == null) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <iflJobResultCallback> is null, please set it" +
|
||||
" before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return iflJobResultCallback;
|
||||
}
|
||||
|
||||
public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback) {
|
||||
this.iflJobResultCallback = iflJobResultCallback;
|
||||
}
|
||||
|
||||
public String getTrainDataset() {
|
||||
if (trainDataset == null || trainDataset.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null or empty, please set " +
|
||||
"it before use"));
|
||||
"it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return trainDataset;
|
||||
}
|
||||
|
||||
public void setTrainDataset(String trainDataset) {
|
||||
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
|
||||
String realTrainDataset = Common.getRealPath(trainDataset);
|
||||
if (Common.checkPath(realTrainDataset)) {
|
||||
this.trainDataset = realTrainDataset;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it " +
|
||||
"before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> does not exist, please check " +
|
||||
"it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
@ -140,38 +244,41 @@ public class FLParameter {
|
|||
public String getVocabFile() {
|
||||
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before " +
|
||||
"use"));
|
||||
"using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return vocabFile;
|
||||
}
|
||||
|
||||
public void setVocabFile(String vocabFile) {
|
||||
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
|
||||
String realVocabFile = Common.getRealPath(vocabFile);
|
||||
if (Common.checkPath(realVocabFile)) {
|
||||
this.vocabFile = realVocabFile;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it " +
|
||||
"before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> does not exist, please check it " +
|
||||
"before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getIdsFile() {
|
||||
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before " +
|
||||
"using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return idsFile;
|
||||
}
|
||||
|
||||
public void setIdsFile(String idsFile) {
|
||||
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
|
||||
String realIdsFile = Common.getRealPath(idsFile);
|
||||
if (Common.checkPath(realIdsFile)) {
|
||||
this.idsFile = realIdsFile;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it " +
|
||||
"before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> does not exist, please check it " +
|
||||
"before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
@ -181,12 +288,13 @@ public class FLParameter {
|
|||
}
|
||||
|
||||
public void setTestDataset(String testDataset) {
|
||||
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
|
||||
String realTestDataset = Common.getRealPath(testDataset);
|
||||
if (Common.checkPath(realTestDataset)) {
|
||||
this.testDataset = realTestDataset;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it " +
|
||||
"before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> does not exist, please check it" +
|
||||
" before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
@ -194,27 +302,20 @@ public class FLParameter {
|
|||
public String getFlName() {
|
||||
if (flName == null || flName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null or empty, please set it " +
|
||||
"before use"));
|
||||
"before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flName;
|
||||
}
|
||||
|
||||
public void setFlName(String flName) {
|
||||
if (Common.checkFLName(flName)) {
|
||||
this.flName = flName;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in FL_NAME_TRUST_LIST: " +
|
||||
Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " +
|
||||
"set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.flName = flName;
|
||||
}
|
||||
|
||||
public String getTrainModelPath() {
|
||||
if (trainModelPath == null || trainModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null or empty, please set" +
|
||||
" it before use"));
|
||||
" it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return trainModelPath;
|
||||
|
@ -225,8 +326,8 @@ public class FLParameter {
|
|||
if (Common.checkPath(realTrainModelPath)) {
|
||||
this.trainModelPath = realTrainModelPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check " +
|
||||
"it before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> does not exist, please " +
|
||||
"check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
@ -234,7 +335,7 @@ public class FLParameter {
|
|||
public String getInferModelPath() {
|
||||
if (inferModelPath == null || inferModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null or empty, please set" +
|
||||
" it before use"));
|
||||
" it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return inferModelPath;
|
||||
|
@ -245,8 +346,8 @@ public class FLParameter {
|
|||
if (Common.checkPath(realInferModelPath)) {
|
||||
this.inferModelPath = realInferModelPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check " +
|
||||
"it before set"));
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> does not exist, please check" +
|
||||
" it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
@ -256,9 +357,31 @@ public class FLParameter {
|
|||
}
|
||||
|
||||
public void setUseSSL(boolean useSSL) {
|
||||
LOGGER.warning(Common.addTag("Certificate authentication is required for https communication,this parameter " +
|
||||
"is true by default and no need to set it, " + Common.LOG_DEPRECATED));
|
||||
this.useSSL = useSSL;
|
||||
}
|
||||
|
||||
public String getSslProtocol() {
|
||||
if (sslProtocol == null || sslProtocol.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <sslProtocol> is null or empty, please set it" +
|
||||
" before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return sslProtocol;
|
||||
}
|
||||
|
||||
public void setSslProtocol(String sslProtocol) {
|
||||
if (Common.checkSSLProtocol(sslProtocol)) {
|
||||
this.sslProtocol = sslProtocol;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <sslProtocol> is not in sslProtocolTrustList " +
|
||||
": " + Arrays.toString(Common.SSL_PROTOCOL_TRUST_LIST.toArray(new String[0])) + ", please check " +
|
||||
"it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public int getTimeOut() {
|
||||
return timeOut;
|
||||
}
|
||||
|
@ -286,7 +409,7 @@ public class FLParameter {
|
|||
public int getServerNum() {
|
||||
if (serverNum <= 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> <= 0, it should be > 0, please " +
|
||||
"set it before use"));
|
||||
"set it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return serverNum;
|
||||
|
@ -296,11 +419,159 @@ public class FLParameter {
|
|||
this.serverNum = serverNum;
|
||||
}
|
||||
|
||||
public String getClientID() {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
|
||||
public boolean isPkiVerify() {
|
||||
return ifPkiVerify;
|
||||
}
|
||||
|
||||
public void setPkiVerify(boolean ifPkiVerify) {
|
||||
this.ifPkiVerify = ifPkiVerify;
|
||||
}
|
||||
|
||||
public String getEquipCrlPath() {
|
||||
if (equipCrlPath == null || equipCrlPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <equipCrlPath> is null or empty, please set " +
|
||||
"it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return clientID;
|
||||
return equipCrlPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains the Valid Iteration Interval set by a user.
|
||||
*
|
||||
* @return the Valid Iteration Interval.
|
||||
*/
|
||||
public long getValidInterval() {
|
||||
if (validIterInterval <= 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <validIterInterval> is not valid, please set " +
|
||||
"it as larger than 0."));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return validIterInterval;
|
||||
}
|
||||
|
||||
public void setEquipCrlPath(String certPath) {
|
||||
String realCertPath = Common.getRealPath(certPath);
|
||||
if (Common.checkPath(realCertPath)) {
|
||||
this.equipCrlPath = realCertPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <equipCrlPath> does not exist, please check " +
|
||||
"it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the Valid Iteration Interval.
|
||||
*
|
||||
* @param validInterval the Valid Iteration Interval.
|
||||
*/
|
||||
public void setValidInterval(long validInterval) {
|
||||
if (validInterval > 0) {
|
||||
this.validIterInterval = validInterval;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <validIterInterval> should be larger than 0, " +
|
||||
"please set it again."));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public int getThreadNum() {
|
||||
return threadNum;
|
||||
}
|
||||
|
||||
public void setThreadNum(int threadNum) {
|
||||
if (threadNum <= 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <threadNum> <= 0, please check it before " +
|
||||
"setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.threadNum = threadNum;
|
||||
}
|
||||
|
||||
public int getCpuBindMode() {
|
||||
LOGGER.info(Common.addTag("[flParameter] the parameter of <cpuBindMode> is: " + cpuBindMode.toString() + " , " +
|
||||
"the NOT_BINDING_CORE means that not binding core, BIND_LARGE_CORE means binding the large core, " +
|
||||
"BIND_MIDDLE_CORE means binding the middle core"));
|
||||
return cpuBindMode.ordinal();
|
||||
}
|
||||
|
||||
public void setCpuBindMode(BindMode cpuBindMode) {
|
||||
this.cpuBindMode = cpuBindMode;
|
||||
}
|
||||
|
||||
public void setHybridWeightName(List<String> hybridWeightName, RunType runType) {
|
||||
if (RunType.TRAINMODE.equals(runType)) {
|
||||
this.trainWeightName = hybridWeightName;
|
||||
} else if (RunType.INFERMODE.equals(runType)) {
|
||||
this.inferWeightName = hybridWeightName;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the variable <runType> can only be set to <RunType.TRAINMODE> " +
|
||||
"or <RunType.INFERMODE>, please check it"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public List<String> getHybridWeightName(RunType runType) {
|
||||
if (RunType.TRAINMODE.equals(runType)) {
|
||||
if (trainWeightName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainWeightName> is null, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return trainWeightName;
|
||||
} else if (RunType.INFERMODE.equals(runType)) {
|
||||
if (inferWeightName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferWeightName> is null, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return inferWeightName;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the variable <runType> can only be set to <RunType.TRAINMODE> " +
|
||||
"or <RunType.INFERMODE>, please check it"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public Map<RunType, List<String>> getDataMap() {
|
||||
if (dataMap.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <dataMaps> is null, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return dataMap;
|
||||
}
|
||||
|
||||
public void setDataMap(Map<RunType, List<String>> dataMap) {
|
||||
this.dataMap = dataMap;
|
||||
}
|
||||
|
||||
public ServerMod getServerMod() {
|
||||
if (serverMod == null) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverMod> is null, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return serverMod;
|
||||
}
|
||||
|
||||
public void setServerMod(ServerMod serverMod) {
|
||||
this.serverMod = serverMod;
|
||||
}
|
||||
|
||||
public int getBatchSize() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
public void setBatchSize(int batchSize) {
|
||||
if (batchSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <batchSize> <= 0, please check it before " +
|
||||
"setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.batchSize = batchSize;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,16 +16,17 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestGetModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
@ -35,6 +36,9 @@ import java.util.ArrayList;
|
|||
import java.util.Date;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Define the serialization method, handle the response message returned from server for getModel request.
|
||||
*
|
||||
|
@ -50,6 +54,7 @@ public class GetModel {
|
|||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int retCode = ResponseCode.RequestError;
|
||||
|
||||
private GetModel() {
|
||||
}
|
||||
|
@ -72,6 +77,10 @@ public class GetModel {
|
|||
return localRef;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a flatBuffer builder of RequestGetModel.
|
||||
*
|
||||
|
@ -88,8 +97,8 @@ public class GetModel {
|
|||
return builder.iteration(iteration).flName(name).time().build();
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
|
||||
private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) {
|
||||
FLClientStatus status;
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
|
@ -113,6 +122,11 @@ public class GetModel {
|
|||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" +
|
||||
" " + feature.dataLength()));
|
||||
}
|
||||
status = Common.initSession(flParameter.getInferModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
|
@ -131,6 +145,7 @@ public class GetModel {
|
|||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Common.freeSession();
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
|
@ -145,6 +160,11 @@ public class GetModel {
|
|||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
|
||||
feature.dataLength()));
|
||||
}
|
||||
status = Common.initSession(flParameter.getInferModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
|
@ -154,11 +174,13 @@ public class GetModel {
|
|||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Common.freeSession();
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseLenet(ResponseGetModel responseDataBuf) {
|
||||
private FLClientStatus deprecatedParseResponseLenet(ResponseGetModel responseDataBuf) {
|
||||
FLClientStatus status;
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
|
@ -172,6 +194,11 @@ public class GetModel {
|
|||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
|
||||
feature.dataLength()));
|
||||
}
|
||||
status = Common.initSession(flParameter.getInferModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
|
@ -180,6 +207,116 @@ public class GetModel {
|
|||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Common.freeSession();
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus deprecatedParseFeatures(ResponseGetModel responseDataBuf) {
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
status = deprecatedParseResponseAlbert(responseDataBuf);
|
||||
} else if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
status = deprecatedParseResponseLenet(responseDataBuf);
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) {
|
||||
FLClientStatus status;
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
if (fmCount <= 0) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
if (flParameter.getHybridWeightName(RunType.TRAINMODE).contains(featureName)) {
|
||||
trainFeatureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] trainWeightFullname: " + feature.weightFullname() + ", " +
|
||||
"trainWeightLength: " + feature.dataLength()));
|
||||
}
|
||||
if (flParameter.getHybridWeightName(RunType.INFERMODE).contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] inferWeightFullname: " + feature.weightFullname() + ", " +
|
||||
"inferWeightLength: " + feature.dataLength()));
|
||||
}
|
||||
}
|
||||
Status tag;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
status = Common.initSession(flParameter.getInferModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
tag = client.updateFeatures(flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
Common.freeSession();
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <Client.updateFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
tag = client.updateFeatures(flParameter.getTrainModelPath(), inferFeatureMaps);
|
||||
Common.freeSession();
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <Client.updateFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
Status tag;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps);
|
||||
LOGGER.info(Common.addTag("[getModel] ===========free session============="));
|
||||
Common.freeSession();
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <Client.updateFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -190,25 +327,21 @@ public class GetModel {
|
|||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
|
||||
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
|
||||
retCode = responseDataBuf.retcode();
|
||||
LOGGER.info(Common.addTag("[getModel] ==========the response message of getModel is:================"));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + retCode));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========reason: " + responseDataBuf.reason()));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========iteration: " + responseDataBuf.iteration()));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========time: " + responseDataBuf.timestamp()));
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
int retCode = responseDataBuf.retcode();
|
||||
switch (retCode) {
|
||||
switch (responseDataBuf.retcode()) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[getModel] getModel response success"));
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
status = parseResponseAlbert(responseDataBuf);
|
||||
} else if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
status = parseResponseLenet(responseDataBuf);
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
status = deprecatedParseFeatures(responseDataBuf);
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
|
||||
throw new IllegalArgumentException();
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseFeatures>"));
|
||||
status = parseResponseFeatures(responseDataBuf);
|
||||
}
|
||||
return status;
|
||||
case (ResponseCode.SucNotReady):
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
|
||||
import org.bouncycastle.math.ec.rfc7748.X25519;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -59,6 +61,16 @@ public class LocalFLParameter {
|
|||
* The model name supported by federated learning tasks: "albert".
|
||||
*/
|
||||
public static final String ALBERT = "albert";
|
||||
|
||||
/**
|
||||
* The deployment environment supported by federated learning tasks: "android".
|
||||
*/
|
||||
public static final String ANDROID = "android";
|
||||
|
||||
/**
|
||||
* The deployment environment supported by federated learning tasks: "x86".
|
||||
*/
|
||||
public static final String X86 = "x86";
|
||||
private static volatile LocalFLParameter localFLParameter;
|
||||
|
||||
private List<String> classifierWeightName = new ArrayList<>();
|
||||
|
@ -67,6 +79,9 @@ public class LocalFLParameter {
|
|||
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
|
||||
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
|
||||
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
|
||||
private boolean stopJobFlag = false;
|
||||
private MSConfig msConfig = new MSConfig();
|
||||
private boolean useSSL = true;
|
||||
|
||||
private LocalFLParameter() {
|
||||
// set classifierWeightName albertWeightName
|
||||
|
@ -143,14 +158,14 @@ public class LocalFLParameter {
|
|||
public void setEncryptLevel(String encryptLevel) {
|
||||
if (encryptLevel == null || encryptLevel.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is null, please check it " +
|
||||
"before set"));
|
||||
"before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
|
||||
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set"));
|
||||
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.encryptLevel = encryptLevel;
|
||||
|
@ -163,7 +178,7 @@ public class LocalFLParameter {
|
|||
public void setEarlyStopMod(String earlyStopMod) {
|
||||
if (earlyStopMod == null || earlyStopMod.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is null, please check it " +
|
||||
"before set"));
|
||||
"before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) &&
|
||||
|
@ -171,7 +186,8 @@ public class LocalFLParameter {
|
|||
(!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) &&
|
||||
(!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is " + earlyStopMod + " ," +
|
||||
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set"));
|
||||
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before " +
|
||||
"setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.earlyStopMod = earlyStopMod;
|
||||
|
@ -184,15 +200,44 @@ public class LocalFLParameter {
|
|||
public void setServerMod(String serverMod) {
|
||||
if (serverMod == null || serverMod.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is null, please check it " +
|
||||
"before set"));
|
||||
"before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) &&
|
||||
(!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is " + serverMod + " , it " +
|
||||
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set"));
|
||||
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before setting"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.serverMod = serverMod;
|
||||
}
|
||||
|
||||
public boolean isStopJobFlag() {
|
||||
return stopJobFlag;
|
||||
}
|
||||
|
||||
public void setStopJobFlag(boolean stopJobFlag) {
|
||||
this.stopJobFlag = stopJobFlag;
|
||||
}
|
||||
|
||||
public MSConfig getMsConfig() {
|
||||
return msConfig;
|
||||
}
|
||||
|
||||
public void setMsConfig(int DeviceType, int threadNum, int cpuBindMode, boolean enable_fp16) {
|
||||
// arg 0: DeviceType:DT_CPU -> 0
|
||||
// arg 1: ThreadNum -> 2
|
||||
// arg 2: cpuBindMode:NO_BIND -> 0
|
||||
// arg 3: enable_fp16 -> false
|
||||
msConfig.init(DeviceType, threadNum, cpuBindMode, enable_fp16);
|
||||
}
|
||||
|
||||
|
||||
public boolean isUseSSL() {
|
||||
return useSSL;
|
||||
}
|
||||
|
||||
public void setUseSSL(boolean useSSL) {
|
||||
this.useSSL = useSSL;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ public class SSLSocketFactoryTools {
|
|||
String domainName = flParameter.getDomainName();
|
||||
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the <domainName> is null or not valid, it should" +
|
||||
" be like as https://...... , please check!"));
|
||||
" be like https://...... , please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (domainName.split("//")[1].split(":").length < 2) {
|
||||
|
@ -87,7 +87,7 @@ public class SSLSocketFactoryTools {
|
|||
|
||||
private void initSslSocketFactory() {
|
||||
try {
|
||||
sslContext = SSLContext.getInstance("TLS");
|
||||
sslContext = SSLContext.getInstance(flParameter.getSslProtocol());
|
||||
x509Certificate = readCert(flParameter.getCertPath());
|
||||
myTrustManager = new MyTrustManager(x509Certificate);
|
||||
sslContext.init(null, new TrustManager[]{
|
||||
|
@ -137,9 +137,8 @@ public class SSLSocketFactoryTools {
|
|||
"convert to X509Certificate"));
|
||||
}
|
||||
} catch (FileNotFoundException | CertificateException ex) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " +
|
||||
"when creating " +
|
||||
"CertificateFactory in readCert: " + ex.getMessage()));
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch exception when creating CertificateFactory in " +
|
||||
"readCert: invalid file or CertificateException"));
|
||||
} finally {
|
||||
try {
|
||||
if (inputStream != null) {
|
||||
|
|
|
@ -16,20 +16,12 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
@ -52,7 +44,7 @@ public class SecureProtocol {
|
|||
private double dpEps;
|
||||
private double dpDelta;
|
||||
private double dpNormClip;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
private ArrayList<String> updateFeatureName = new ArrayList<String>();
|
||||
private int retCode;
|
||||
|
||||
/**
|
||||
|
@ -115,17 +107,17 @@ public class SecureProtocol {
|
|||
*
|
||||
* @return the feature names that needed to be encrypted.
|
||||
*/
|
||||
public ArrayList<String> getEncryptFeatureName() {
|
||||
return encryptFeatureName;
|
||||
public ArrayList<String> getUpdateFeatureName() {
|
||||
return updateFeatureName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the parameter encryptFeatureName.
|
||||
* Set the parameter updateFeatureName.
|
||||
*
|
||||
* @param encryptFeatureName the feature names that needed to be encrypted.
|
||||
* @param updateFeatureName the feature names that needed to be encrypted.
|
||||
*/
|
||||
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
|
||||
this.encryptFeatureName = encryptFeatureName;
|
||||
public void setUpdateFeatureName(ArrayList<String> updateFeatureName) {
|
||||
this.updateFeatureName = updateFeatureName;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -146,6 +138,10 @@ public class SecureProtocol {
|
|||
LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============",
|
||||
localFLParameter.getFlID()));
|
||||
// round 0
|
||||
if (localFLParameter.isStopJobFlag()) {
|
||||
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
|
||||
return status;
|
||||
}
|
||||
status = cipherClient.exchangeKeys();
|
||||
retCode = cipherClient.getRetCode();
|
||||
LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ",
|
||||
|
@ -154,6 +150,10 @@ public class SecureProtocol {
|
|||
return status;
|
||||
}
|
||||
// round 1
|
||||
if (localFLParameter.isStopJobFlag()) {
|
||||
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
|
||||
return status;
|
||||
}
|
||||
status = cipherClient.shareSecrets();
|
||||
retCode = cipherClient.getRetCode();
|
||||
LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ",
|
||||
|
@ -162,6 +162,10 @@ public class SecureProtocol {
|
|||
return status;
|
||||
}
|
||||
// round2
|
||||
if (localFLParameter.isStopJobFlag()) {
|
||||
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
|
||||
return status;
|
||||
}
|
||||
featureMask = cipherClient.doubleMaskingWeight();
|
||||
if (featureMask == null || featureMask.length <= 0) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" +
|
||||
|
@ -180,30 +184,18 @@ public class SecureProtocol {
|
|||
* @param trainDataSize trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
|
||||
if (featureMask == null || featureMask.length == 0) {
|
||||
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
||||
return new int[0];
|
||||
}
|
||||
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int featureSize = updateFeatureName.size();
|
||||
int[] featuresMap = new int[featureSize];
|
||||
int maskIndex = 0;
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] data = trainedMap.get(key);
|
||||
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
|
@ -349,21 +341,10 @@ public class SecureProtocol {
|
|||
* @param trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
Map<String, float[]> mapBeforeTrain = modelMap;
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int featureSize = updateFeatureName.size();
|
||||
// calculate sigma
|
||||
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
|
||||
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
|
||||
|
@ -371,8 +352,8 @@ public class SecureProtocol {
|
|||
// calculate l2-norm of all layers' update array
|
||||
double updateL2Norm = 0d;
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] data = trainedMap.get(key);
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
|
@ -391,12 +372,12 @@ public class SecureProtocol {
|
|||
// clip and add noise
|
||||
int[] featuresMap = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
if (!map.containsKey(key)) {
|
||||
String key = updateFeatureName.get(i);
|
||||
if (!trainedMap.containsKey(key)) {
|
||||
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
|
||||
return new int[0];
|
||||
}
|
||||
float[] data = map.get(key);
|
||||
float[] data = trainedMap.get(key);
|
||||
float[] data2 = new float[data.length];
|
||||
if (!mapBeforeTrain.containsKey(key)) {
|
||||
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
|
||||
|
|
|
@ -1,29 +1,21 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import com.mindspore.flclient.pki.PkiBean;
|
||||
import com.mindspore.flclient.pki.PkiUtil;
|
||||
|
||||
import mindspore.schema.FLPlan;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -31,9 +23,14 @@ import mindspore.schema.RequestFLJob;
|
|||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseFLJob;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.security.cert.Certificate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* StartFLJob
|
||||
*
|
||||
|
@ -52,11 +49,425 @@ public class StartFLJob {
|
|||
|
||||
private int featureSize;
|
||||
private String nextRequestTime;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
private ArrayList<String> updateFeatureName = new ArrayList<String>();
|
||||
private int retCode = ResponseCode.RequestError;
|
||||
private float lr = (float) 0.1;
|
||||
private int batchSize;
|
||||
|
||||
private StartFLJob() {
|
||||
}
|
||||
|
||||
/**
|
||||
* getInstance of StartFLJob
|
||||
*
|
||||
* @return StartFLJob instance
|
||||
*/
|
||||
public static StartFLJob getInstance() {
|
||||
StartFLJob localRef = startFLJob;
|
||||
if (localRef == null) {
|
||||
synchronized (StartFLJob.class) {
|
||||
localRef = startFLJob;
|
||||
if (localRef == null) {
|
||||
startFLJob = localRef = new StartFLJob();
|
||||
}
|
||||
}
|
||||
}
|
||||
return localRef;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* get request start FLJob
|
||||
*
|
||||
* @param dataSize dataSize
|
||||
* @param iteration iteration
|
||||
* @param time time
|
||||
* @param pkiBean pki bean
|
||||
* @return byte[] data
|
||||
*/
|
||||
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time, PkiBean pkiBean) {
|
||||
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
|
||||
|
||||
if (flParameter.isPkiVerify()) {
|
||||
if (pkiBean == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <pkiBean> is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return builder.flName(flParameter.getFlName())
|
||||
.time(time)
|
||||
.id(localFLParameter.getFlID())
|
||||
.dataSize(dataSize)
|
||||
.iteration(iteration)
|
||||
.signData(pkiBean.getSignData())
|
||||
.certificateChain(pkiBean.getCertificates())
|
||||
.build();
|
||||
}
|
||||
return builder.flName(flParameter.getFlName())
|
||||
.time(time)
|
||||
.id(localFLParameter.getFlID())
|
||||
.dataSize(dataSize)
|
||||
.iteration(iteration)
|
||||
.build();
|
||||
}
|
||||
|
||||
public int getFeatureSize() {
|
||||
return featureSize;
|
||||
}
|
||||
|
||||
public ArrayList<String> getUpdateFeatureName() {
|
||||
return updateFeatureName;
|
||||
}
|
||||
|
||||
private FLClientStatus deprecatedParseResponseAlbert(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
int fmCount = flJob.featureMapLength();
|
||||
updateFeatureName.clear();
|
||||
if (fmCount <= 0) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
updateFeatureName.add(feature.weightFullname());
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
|
||||
inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Common.freeSession();
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
updateFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Common.freeSession();
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
updateFeatureName.clear();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
updateFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
|
||||
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
Common.freeSession();
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus hybridFeatures(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
if (flParameter.getHybridWeightName(RunType.TRAINMODE).contains(featureName)) {
|
||||
trainFeatureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
updateFeatureName.add(feature.weightFullname());
|
||||
LOGGER.info(Common.addTag("[startFLJob] trainWeightFullname: " + feature.weightFullname() + ", " +
|
||||
"trainWeightLength: " + feature.dataLength()));
|
||||
}
|
||||
if (flParameter.getHybridWeightName(RunType.INFERMODE).contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[startFLJob] inferWeightFullname: " + feature.weightFullname() + ", " +
|
||||
"inferWeightLength: " + feature.dataLength()));
|
||||
}
|
||||
}
|
||||
Status tag;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
status = Common.initSession(flParameter.getInferModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
tag = client.updateFeatures(flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
Common.freeSession();
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <Client.updateFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
|
||||
tag = client.setLearningRate(lr);
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
|
||||
client.setBatchSize(batchSize);
|
||||
tag = client.updateFeatures(flParameter.getTrainModelPath(), inferFeatureMaps);
|
||||
Common.freeSession();
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <Client.updateFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
private FLClientStatus normalFeatures(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
updateFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
Status tag;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
|
||||
tag = client.setLearningRate(lr);
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
|
||||
client.setBatchSize(batchSize);
|
||||
tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps);
|
||||
LOGGER.info(Common.addTag("[startFLJob] ===========free session============="));
|
||||
Common.freeSession();
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <Client.updateFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseFeatures(ResponseFLJob flJob) {
|
||||
FLClientStatus status;
|
||||
int fmCount = flJob.featureMapLength();
|
||||
updateFeatureName.clear();
|
||||
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseFeatures by " + localFLParameter.getServerMod()));
|
||||
status = hybridFeatures(flJob);
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
return status;
|
||||
}
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseFeatures by " + localFLParameter.getServerMod()));
|
||||
status = normalFeatures(flJob);
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus deprecatedParseFeatures(ResponseFLJob flJob) {
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
|
||||
status = deprecatedParseResponseAlbert(flJob);
|
||||
}
|
||||
if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
||||
status = deprecatedParseResponseLenet(flJob);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* response res
|
||||
*
|
||||
* @param flJob ResponseFLJob
|
||||
* @return FLClientStatus
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseFLJob flJob) {
|
||||
if (flJob == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FLPlan flPlanConfig = flJob.flPlanConfig();
|
||||
if (flPlanConfig == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
if (flJob.featureMapLength() <= 0) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
retCode = ResponseCode.SystemError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
retCode = flJob.retcode();
|
||||
LOGGER.info(Common.addTag("[startFLJob] ==========the response message of startFLJob is:================"));
|
||||
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + retCode));
|
||||
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
|
||||
nextRequestTime = flJob.nextReqTime();
|
||||
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
|
||||
FLClientStatus status;
|
||||
int responseRetCode = flJob.retcode();
|
||||
|
||||
switch (responseRetCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
localFLParameter.setServerMod(flPlanConfig.serverMode());
|
||||
if (flPlanConfig.lr() != 0) {
|
||||
lr = flPlanConfig.lr();
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <lr> from server: " + lr + " is not " +
|
||||
"valid, " +
|
||||
"will use the default value 0.1"));
|
||||
}
|
||||
batchSize = flPlanConfig.miniBatch();
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
status = deprecatedParseFeatures(flJob);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseFeatures>"));
|
||||
status = parseResponseFeatures(flJob);
|
||||
}
|
||||
return status;
|
||||
case (ResponseCode.OutOfTime):
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the return <retCode> from server is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
class RequestStartFLJobBuilder {
|
||||
private RequestFLJob requestFLJob;
|
||||
private FlatBufferBuilder builder;
|
||||
|
@ -65,6 +476,11 @@ public class StartFLJob {
|
|||
private int dataSize = 0;
|
||||
private int timestampOffset = 0;
|
||||
private int idOffset = 0;
|
||||
private int signDataOffset = 0;
|
||||
private int keyAttestationOffset = 0;
|
||||
private int equipCertOffset = 0;
|
||||
private int equipCACertOffset = 0;
|
||||
private int rootCertOffset = 0;
|
||||
|
||||
public RequestStartFLJobBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
|
@ -135,6 +551,50 @@ public class StartFLJob {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* signData
|
||||
*
|
||||
* @param signData byte[]
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder signData(byte[] signData) {
|
||||
if (signData == null || signData.length == 0) {
|
||||
LOGGER.severe(
|
||||
Common.addTag("[startFLJob] the parameter of <signData> is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.signDataOffset = RequestFLJob.createSignDataVector(builder, signData);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* set certificateChain
|
||||
*
|
||||
* @param certificates Certificate array
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder certificateChain(Certificate[] certificates) {
|
||||
if (certificates == null || certificates.length < 4) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <certificates> is null or the length "
|
||||
+ "is not valid (should be >= 4), please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
try {
|
||||
String keyAttestationPem = PkiUtil.getPemFormat(certificates[0]);
|
||||
String equipCertPem = PkiUtil.getPemFormat(certificates[1]);
|
||||
String equipCACertPem = PkiUtil.getPemFormat(certificates[2]);
|
||||
String rootCertPem = PkiUtil.getPemFormat(certificates[3]);
|
||||
|
||||
this.keyAttestationOffset = this.builder.createString(keyAttestationPem);
|
||||
this.equipCertOffset = this.builder.createString(equipCertPem);
|
||||
this.equipCACertOffset = this.builder.createString(equipCACertPem);
|
||||
this.rootCertOffset = this.builder.createString(rootCertPem);
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[StartFLJob] catch IOException in certificateChain: " + e.getMessage()));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* build protobuffer
|
||||
*
|
||||
|
@ -152,210 +612,4 @@ public class StartFLJob {
|
|||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* getInstance of StartFLJob
|
||||
*
|
||||
* @return StartFLJob instance
|
||||
*/
|
||||
public static StartFLJob getInstance() {
|
||||
StartFLJob localRef = startFLJob;
|
||||
if (localRef == null) {
|
||||
synchronized (StartFLJob.class) {
|
||||
localRef = startFLJob;
|
||||
if (localRef == null) {
|
||||
startFLJob = localRef = new StartFLJob();
|
||||
}
|
||||
}
|
||||
}
|
||||
return localRef;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* get request start FLJob
|
||||
*
|
||||
* @param dataSize dataSize
|
||||
* @param iteration iteration
|
||||
* @param time time
|
||||
* @return byte[] data
|
||||
*/
|
||||
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
|
||||
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
|
||||
return builder.flName(flParameter.getFlName())
|
||||
.time(time)
|
||||
.id(localFLParameter.getFlID())
|
||||
.dataSize(dataSize)
|
||||
.iteration(iteration)
|
||||
.build();
|
||||
}
|
||||
|
||||
public int getFeatureSize() {
|
||||
return featureSize;
|
||||
}
|
||||
|
||||
public ArrayList<String> getEncryptFeatureName() {
|
||||
return encryptFeatureName;
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
|
||||
int fmCount = flJob.featureMapLength();
|
||||
encryptFeatureName.clear();
|
||||
if (fmCount <= 0) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(feature.weightFullname());
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
|
||||
inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseLenet(ResponseFLJob flJob) {
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
encryptFeatureName.clear();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
|
||||
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* response res
|
||||
*
|
||||
* @param flJob ResponseFLJob
|
||||
* @return FLClientStatus
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseFLJob flJob) {
|
||||
if (flJob == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FLPlan flPlanConfig = flJob.flPlanConfig();
|
||||
if (flPlanConfig == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
|
||||
nextRequestTime = flJob.nextReqTime();
|
||||
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
int retCode = flJob.retcode();
|
||||
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
localFLParameter.setServerMod(flPlanConfig.serverMode());
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
|
||||
status = parseResponseAlbert(flJob);
|
||||
}
|
||||
if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
||||
status = parseResponseLenet(flJob);
|
||||
}
|
||||
return status;
|
||||
case (ResponseCode.OutOfTime):
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the return <retCode> from server is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,19 +18,26 @@ package com.mindspore.flclient;
|
|||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ANDROID;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.RunType;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import com.mindspore.flclient.pki.PkiUtil;
|
||||
import com.mindspore.lite.config.CpuBindMode;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
@ -46,8 +53,38 @@ public class SyncFLJob {
|
|||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
|
||||
private Map<String, float[]> oldFeatureMap;
|
||||
private IFLJobResultCallback flJobResultCallback;
|
||||
private FLClientStatus curStatus;
|
||||
|
||||
private void initFlIDForPkiVerify() {
|
||||
if (flParameter.isPkiVerify()) {
|
||||
LOGGER.info(Common.addTag("pkiVerify mode is open!"));
|
||||
String equipCertHash = PkiUtil.genEquipCertHash(flParameter.getClientID());
|
||||
if (equipCertHash == null || equipCertHash.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("equipCertHash is empty, please check your mobile phone, only Huawei " +
|
||||
"phones are supported now."));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
LOGGER.info(Common.addTag("flID for pki verify is: " + equipCertHash));
|
||||
localFLParameter.setFlID(equipCertHash);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("pkiVerify mode is not open!"));
|
||||
localFLParameter.setFlID(flParameter.getClientID());
|
||||
}
|
||||
}
|
||||
|
||||
public SyncFLJob() {
|
||||
if (!Common.checkFLName(flParameter.getFlName())) {
|
||||
try {
|
||||
LOGGER.info(Common.addTag("the flName: " + flParameter.getFlName()));
|
||||
Class.forName(flParameter.getFlName());
|
||||
} catch (ClassNotFoundException e) {
|
||||
LOGGER.severe(Common.addTag("catch ClassNotFoundException error, the set flName does not exist, please " +
|
||||
"check: " + e.getMessage()));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a federated learning task on the device.
|
||||
|
@ -55,209 +92,172 @@ public class SyncFLJob {
|
|||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus flJobRun() {
|
||||
Common.setSecureRandom(new SecureRandom());
|
||||
localFLParameter.setFlID(flParameter.getClientID());
|
||||
FLLiteClient client = new FLLiteClient();
|
||||
FLClientStatus curStatus;
|
||||
curStatus = client.initSession();
|
||||
if (curStatus == FLClientStatus.FAILED) {
|
||||
LOGGER.severe(Common.addTag("init session failed"));
|
||||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
||||
return curStatus;
|
||||
flJobResultCallback = flParameter.getIflJobResultCallback();
|
||||
if (!Common.checkFLName(flParameter.getFlName()) && ANDROID.equals(flParameter.getDeployEnv())) {
|
||||
Common.setSecureRandom(Common.getFastSecureRandom());
|
||||
} else {
|
||||
Common.setSecureRandom(new SecureRandom());
|
||||
}
|
||||
|
||||
initFlIDForPkiVerify();
|
||||
localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false);
|
||||
FLLiteClient flLiteClient = new FLLiteClient();
|
||||
LOGGER.info(Common.addTag("recovery StopJobFlag to false in the start of fl job"));
|
||||
localFLParameter.setStopJobFlag(false);
|
||||
do {
|
||||
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
|
||||
int trainDataSize = client.setInput(flParameter.getTrainDataset());
|
||||
if (trainDataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
|
||||
curStatus = FLClientStatus.FAILED;
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
|
||||
client.getRetCode());
|
||||
if (checkStopJobFlag()) {
|
||||
break;
|
||||
}
|
||||
client.setTrainDataSize(trainDataSize);
|
||||
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
|
||||
int trainDataSize = flLiteClient.setInput();
|
||||
if (trainDataSize <= 0) {
|
||||
curStatus = FLClientStatus.FAILED;
|
||||
failed("unsolved error code in <flLiteClient.setInput>: the return trainDataSize<=0, setInput",
|
||||
flLiteClient);
|
||||
break;
|
||||
}
|
||||
flLiteClient.setTrainDataSize(trainDataSize);
|
||||
|
||||
// startFLJob
|
||||
curStatus = startFLJob(client);
|
||||
curStatus = startFLJob(flLiteClient);
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
restart("[startFLJob]", flLiteClient.getNextRequestTime(), flLiteClient);
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[startFLJob]", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[startFLJob]", flLiteClient);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration()));
|
||||
|
||||
// get the feature map before train
|
||||
getOldFeatureMap(client);
|
||||
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration()));
|
||||
|
||||
// create mask
|
||||
curStatus = client.getFeatureMask();
|
||||
curStatus = flLiteClient.getFeatureMask();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[Encrypt] creatMask", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
restart("[Encrypt] creatMask", flLiteClient.getNextRequestTime(), flLiteClient);
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[Encrypt] createMask", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[Encrypt] createMask", flLiteClient);
|
||||
break;
|
||||
}
|
||||
|
||||
// train
|
||||
curStatus = client.localTrain();
|
||||
curStatus = flLiteClient.localTrain();
|
||||
if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[train] train", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[train] train", flLiteClient);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[train] train succeed"));
|
||||
|
||||
// updateModel
|
||||
curStatus = updateModel(client);
|
||||
curStatus = updateModel(flLiteClient);
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
restart("[updateModel]", flLiteClient.getNextRequestTime(), flLiteClient);
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[updateModel] updateModel", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[updateModel] updateModel", flLiteClient);
|
||||
break;
|
||||
}
|
||||
|
||||
// unmasking
|
||||
curStatus = client.unMasking();
|
||||
curStatus = flLiteClient.unMasking();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[Encrypt] unmasking", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
restart("[Encrypt] unmasking", flLiteClient.getNextRequestTime(), flLiteClient);
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[Encrypt] unmasking", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[Encrypt] unmasking", flLiteClient);
|
||||
break;
|
||||
}
|
||||
|
||||
// getModel
|
||||
curStatus = getModel(client);
|
||||
curStatus = getModel(flLiteClient);
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
restart("[getModel]", flLiteClient.getNextRequestTime(), flLiteClient);
|
||||
continue;
|
||||
} else if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[getModel] getModel", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[getModel] getModel", flLiteClient);
|
||||
break;
|
||||
}
|
||||
|
||||
// get the feature map after averaging and update dp_norm_clip
|
||||
updateDpNormClip(client);
|
||||
flLiteClient.updateDpNormClip();
|
||||
|
||||
// evaluate model after getting model from server
|
||||
if (flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" +
|
||||
" model from server"));
|
||||
if (!checkEvalPath()) {
|
||||
LOGGER.info(Common.addTag("[evaluate] the data map set by user do not contain evaluation dataset, " +
|
||||
"don't evaluate the model after getting model from server"));
|
||||
} else {
|
||||
curStatus = client.evaluateModel();
|
||||
curStatus = flLiteClient.evaluateModel();
|
||||
if (curStatus == FLClientStatus.FAILED) {
|
||||
failed("[evaluate] evaluate", client.getIteration(), client.getRetCode(), curStatus);
|
||||
failed("[evaluate] evaluate", flLiteClient);
|
||||
break;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
|
||||
}
|
||||
LOGGER.info(Common.addTag("========================================================the total response of "
|
||||
+ client.getIteration() + ": " + curStatus +
|
||||
+ flLiteClient.getIteration() + ": " + curStatus +
|
||||
"======================================================================"));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
|
||||
client.getRetCode());
|
||||
} while (client.getIteration() < client.getIterations());
|
||||
client.freeSession();
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(),
|
||||
flLiteClient.getRetCode());
|
||||
Common.freeSession();
|
||||
} while (flLiteClient.getIteration() < flLiteClient.getIterations());
|
||||
LOGGER.info(Common.addTag("flJobRun finish"));
|
||||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
||||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), flLiteClient.getIterations(),
|
||||
flLiteClient.getRetCode());
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus startFLJob(FLLiteClient client) {
|
||||
FLClientStatus curStatus = client.startFLJob();
|
||||
private FLClientStatus startFLJob(FLLiteClient flLiteClient) {
|
||||
FLClientStatus curStatus = flLiteClient.startFLJob();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.startFLJob();
|
||||
curStatus = flLiteClient.startFLJob();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus updateModel(FLLiteClient client) {
|
||||
FLClientStatus curStatus = client.updateModel();
|
||||
private FLClientStatus updateModel(FLLiteClient flLiteClient) {
|
||||
FLClientStatus curStatus = flLiteClient.updateModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.updateModel();
|
||||
curStatus = flLiteClient.updateModel();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus getModel(FLLiteClient client) {
|
||||
FLClientStatus curStatus = client.getModel();
|
||||
private FLClientStatus getModel(FLLiteClient flLiteClient) {
|
||||
FLClientStatus curStatus = flLiteClient.getModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.getModel();
|
||||
curStatus = flLiteClient.getModel();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private void updateDpNormClip(FLLiteClient client) {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
int currentIter = client.getIteration();
|
||||
Map<String, float[]> fedFeatureMap = getFeatureMap();
|
||||
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
|
||||
if (fedWeightUpdateNorm == -1) {
|
||||
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
|
||||
"-1, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
private boolean checkEvalPath() {
|
||||
boolean tag = true;
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
if ("null".equals(flParameter.getTestDataset())) {
|
||||
tag = false;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
|
||||
float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm;
|
||||
if (currentIter == 1) {
|
||||
client.setDpNormClipAdapt(newNormCLip);
|
||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||
} else {
|
||||
if (newNormCLip < client.getDpNormClipAdapt()) {
|
||||
client.setDpNormClipAdapt(newNormCLip);
|
||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||
}
|
||||
}
|
||||
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt()));
|
||||
return tag;
|
||||
}
|
||||
if (!flParameter.getDataMap().containsKey(RunType.EVALMODE)) {
|
||||
LOGGER.info(Common.addTag("[evaluate] the data map set by user do not contain evaluation dataset, " +
|
||||
"don't evaluate the model after getting model from server"));
|
||||
tag = false;
|
||||
return tag;
|
||||
}
|
||||
return tag;
|
||||
}
|
||||
|
||||
private boolean checkStopJobFlag() {
|
||||
if (localFLParameter.isStopJobFlag()) {
|
||||
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
|
||||
curStatus = FLClientStatus.FAILED;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private void getOldFeatureMap(FLLiteClient client) {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
Map<String, float[]> featureMap = getFeatureMap();
|
||||
oldFeatureMap = client.getOldMapCopy(featureMap);
|
||||
}
|
||||
}
|
||||
|
||||
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
|
||||
float updateL2Norm = 0f;
|
||||
for (String key : originalData.keySet()) {
|
||||
float[] data = originalData.get(key);
|
||||
float[] dataAfterUpdate = newData.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
if (j >= dataAfterUpdate.length) {
|
||||
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
|
||||
"please check");
|
||||
return -1;
|
||||
}
|
||||
float updateData = data[j] - dataAfterUpdate[j];
|
||||
updateL2Norm += updateData * updateData;
|
||||
}
|
||||
}
|
||||
updateL2Norm = (float) Math.sqrt(updateL2Norm);
|
||||
return updateL2Norm;
|
||||
}
|
||||
|
||||
private Map<String, float[]> getFeatureMap() {
|
||||
Map<String, float[]> featureMap = new HashMap<>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
return featureMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts an inference task on the device.
|
||||
|
@ -265,6 +265,292 @@ public class SyncFLJob {
|
|||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public int[] modelInference() {
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
return deprecatedModelInference();
|
||||
}
|
||||
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false);
|
||||
localFLParameter.setStopJobFlag(false);
|
||||
int[] labels = new int[0];
|
||||
Map<RunType, Integer> dataSize = client.initDataSets(flParameter.getDataMap());
|
||||
if (dataSize.isEmpty()) {
|
||||
LOGGER.severe("[model inference] initDataSets failed, please check");
|
||||
return new int[0];
|
||||
}
|
||||
Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
|
||||
if (!Status.SUCCESS.equals(tag)) {
|
||||
LOGGER.severe(Common.addTag("[model inference] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
" status is: " + tag));
|
||||
return new int[0];
|
||||
}
|
||||
client.setBatchSize(flParameter.getBatchSize());
|
||||
LOGGER.info(Common.addTag("===========model inference============="));
|
||||
labels = client.inferModel().stream().mapToInt(Integer::valueOf).toArray();
|
||||
if (labels == null || labels.length == 0) {
|
||||
LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please " +
|
||||
"check");
|
||||
}
|
||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||
client.free();
|
||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
return labels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains the latest model on the cloud.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus getModel() {
|
||||
if (!Common.checkFLName(flParameter.getFlName()) && ANDROID.equals(flParameter.getDeployEnv())) {
|
||||
Common.setSecureRandom(Common.getFastSecureRandom());
|
||||
} else {
|
||||
Common.setSecureRandom(new SecureRandom());
|
||||
}
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
return deprecatedGetModel();
|
||||
}
|
||||
localFLParameter.setServerMod(flParameter.getServerMod().toString());
|
||||
localFLParameter.setMsConfig(0, 1, 0, false);
|
||||
FLClientStatus status;
|
||||
FLLiteClient flLiteClient = new FLLiteClient();
|
||||
status = flLiteClient.getModel();
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* use to stop FL job.
|
||||
*/
|
||||
public void stopFLJob() {
|
||||
LOGGER.info(Common.addTag("will stop the flJob"));
|
||||
localFLParameter.setStopJobFlag(true);
|
||||
Common.notifyObject();
|
||||
}
|
||||
|
||||
private void waitSomeTime() {
|
||||
if (flParameter.getSleepTime() != 0) {
|
||||
Common.sleep(flParameter.getSleepTime());
|
||||
} else {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
}
|
||||
}
|
||||
|
||||
private void waitNextReqTime(String nextReqTime) {
|
||||
long waitTime = Common.getWaitTime(nextReqTime);
|
||||
Common.sleep(waitTime);
|
||||
}
|
||||
|
||||
private void restart(String tag, String nextReqTime, FLLiteClient flLiteClient) {
|
||||
LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again"));
|
||||
waitNextReqTime(nextReqTime);
|
||||
Common.freeSession();
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(),
|
||||
flLiteClient.getRetCode());
|
||||
}
|
||||
|
||||
private void failed(String tag, FLLiteClient flLiteClient) {
|
||||
LOGGER.info(Common.addTag(tag + " failed"));
|
||||
LOGGER.info(Common.addTag("=========================================the total response of " +
|
||||
flLiteClient.getIteration() + ": " + curStatus + "========================================="));
|
||||
Common.freeSession();
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(),
|
||||
flLiteClient.getRetCode());
|
||||
}
|
||||
|
||||
private static Map<RunType, List<String>> createDatasetMap(String trainDataPath, String evalDataPath,
|
||||
String inferDataPath, String pathRegex) {
|
||||
Map<RunType, List<String>> dataMap = new HashMap<>();
|
||||
if ((trainDataPath == null) || ("null".equals(trainDataPath)) || (trainDataPath.isEmpty())) {
|
||||
LOGGER.info(Common.addTag("the trainDataPath is null or empty, please check if you are in the case of " +
|
||||
"noly inference"));
|
||||
} else {
|
||||
dataMap.put(RunType.TRAINMODE, Arrays.asList(trainDataPath.split(pathRegex)));
|
||||
LOGGER.info(Common.addTag("the trainDataPath: " + Arrays.toString(trainDataPath.split(pathRegex))));
|
||||
}
|
||||
|
||||
if ((evalDataPath == null) || ("null".equals(evalDataPath)) || (evalDataPath.isEmpty())) {
|
||||
LOGGER.info(Common.addTag("the evalDataPath is null or empty, please check if you are in the case of only" +
|
||||
" trainning without evaluation"));
|
||||
} else {
|
||||
dataMap.put(RunType.EVALMODE, Arrays.asList(evalDataPath.split(pathRegex)));
|
||||
LOGGER.info(Common.addTag("the evalDataPath: " + Arrays.toString(evalDataPath.split(pathRegex))));
|
||||
}
|
||||
|
||||
if ((inferDataPath == null) || ("null".equals(inferDataPath)) || (inferDataPath.isEmpty())) {
|
||||
LOGGER.info(Common.addTag("the inferDataPath is null or empty, please check if you are in the case of " +
|
||||
"trainning without inference"));
|
||||
} else {
|
||||
dataMap.put(RunType.INFERMODE, Arrays.asList(inferDataPath.split(pathRegex)));
|
||||
LOGGER.info(Common.addTag("the inferDataPath: " + Arrays.toString(inferDataPath.split(pathRegex))));
|
||||
}
|
||||
return dataMap;
|
||||
}
|
||||
|
||||
private static void createWeightNameList(String trainWeightName, String inferWeightName, String nameRegex,
|
||||
FLParameter flParameter) {
|
||||
if ((trainWeightName == null) || ("null".equals(trainWeightName)) || (trainWeightName.isEmpty())) {
|
||||
LOGGER.info(Common.addTag("the trainWeightName is null or empty, only need in " + ServerMod.HYBRID_TRAINING));
|
||||
} else {
|
||||
flParameter.setHybridWeightName(Arrays.asList(trainWeightName.split(nameRegex)), RunType.TRAINMODE);
|
||||
LOGGER.info(Common.addTag("the trainWeightName: " + Arrays.toString(trainWeightName.split(nameRegex))));
|
||||
}
|
||||
|
||||
if ((inferWeightName == null) || ("null".equals(inferWeightName)) || (inferWeightName.isEmpty())) {
|
||||
LOGGER.info(Common.addTag("the inferWeightName is null or empty, only need in " + ServerMod.HYBRID_TRAINING));
|
||||
} else {
|
||||
flParameter.setHybridWeightName(Arrays.asList(inferWeightName.split(nameRegex)), RunType.INFERMODE);
|
||||
LOGGER.info(Common.addTag("the trainWeightName: " + Arrays.toString(inferWeightName.split(nameRegex))));
|
||||
}
|
||||
}
|
||||
|
||||
private static void task(String[] args) {
|
||||
String trainDataPath = args[0];
|
||||
String evalDataPath = args[1];
|
||||
String inferDataPath = args[2];
|
||||
String pathRegex = args[3];
|
||||
|
||||
String flName = args[4];
|
||||
String trainModelPath = args[5];
|
||||
String inferModelPath = args[6];
|
||||
String sslProtocol = args[7];
|
||||
String deployEnv = args[8];
|
||||
String domainName = args[9];
|
||||
String certPath = args[10];
|
||||
boolean useElb = Boolean.parseBoolean(args[11]);
|
||||
int serverNum = Integer.parseInt(args[12]);
|
||||
String task = args[13];
|
||||
int threadNum = Integer.parseInt(args[14]);
|
||||
String cpuBindMode = args[15];
|
||||
String trainWeightName = args[16];
|
||||
String inferWeightName = args[17];
|
||||
String nameRegex = args[18];
|
||||
String serverMod = args[19];
|
||||
int batchSize = Integer.parseInt(args[20]);
|
||||
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
|
||||
// create dataset of map
|
||||
Map<RunType, List<String>> dataMap = createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex);
|
||||
|
||||
// create weight name of list
|
||||
createWeightNameList(trainWeightName, inferWeightName, nameRegex, flParameter);
|
||||
|
||||
flParameter.setFlName(flName);
|
||||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
Common.setIsHttps(domainName.split("//")[0].split(":")[0]);
|
||||
switch (task) {
|
||||
case "train":
|
||||
LOGGER.info(Common.addTag("start syncFLJob.flJobRun()"));
|
||||
if (Common.isHttps()) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setDataMap(dataMap);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setSslProtocol(sslProtocol);
|
||||
flParameter.setDeployEnv(deployEnv);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
flParameter.setThreadNum(threadNum);
|
||||
flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
|
||||
flParameter.setBatchSize(batchSize);
|
||||
syncFLJob.flJobRun();
|
||||
break;
|
||||
case "inference":
|
||||
LOGGER.info(Common.addTag("start syncFLJob.modelInference()"));
|
||||
flParameter.setDataMap(dataMap);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setThreadNum(threadNum);
|
||||
flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
|
||||
flParameter.setBatchSize(batchSize);
|
||||
syncFLJob.modelInference();
|
||||
break;
|
||||
case "getModel":
|
||||
LOGGER.info(Common.addTag("start syncFLJob.getModel()"));
|
||||
if (Common.isHttps()) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setSslProtocol(sslProtocol);
|
||||
flParameter.setDeployEnv(deployEnv);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
flParameter.setServerMod(ServerMod.valueOf(serverMod));
|
||||
syncFLJob.getModel();
|
||||
break;
|
||||
default:
|
||||
LOGGER.info(Common.addTag("do not do any thing!"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static void deprecatedTask(String[] args) {
|
||||
String trainDataset = args[0];
|
||||
String vocabFile = args[1];
|
||||
String idsFile = args[2];
|
||||
String testDataset = args[3];
|
||||
String flName = args[4];
|
||||
String trainModelPath = args[5];
|
||||
String inferModelPath = args[6];
|
||||
boolean useSSL = Boolean.parseBoolean(args[7]);
|
||||
String domainName = args[8];
|
||||
boolean useElb = Boolean.parseBoolean(args[9]);
|
||||
int serverNum = Integer.parseInt(args[10]);
|
||||
String certPath = args[11];
|
||||
String task = args[12];
|
||||
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
flParameter.setFlName(flName);
|
||||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
Common.setIsHttps(domainName.split("//")[0].split(":")[0]);
|
||||
switch (task) {
|
||||
case "train":
|
||||
if (Common.isHttps()) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setTrainDataset(trainDataset);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setTestDataset(testDataset);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
if (ALBERT.equals(flName)) {
|
||||
flParameter.setVocabFile(vocabFile);
|
||||
flParameter.setIdsFile(idsFile);
|
||||
}
|
||||
syncFLJob.flJobRun();
|
||||
break;
|
||||
case "inference":
|
||||
flParameter.setTestDataset(testDataset);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
if (ALBERT.equals(flName)) {
|
||||
flParameter.setVocabFile(vocabFile);
|
||||
flParameter.setIdsFile(idsFile);
|
||||
}
|
||||
syncFLJob.modelInference();
|
||||
break;
|
||||
case "getModel":
|
||||
if (Common.isHttps()) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
syncFLJob.getModel();
|
||||
break;
|
||||
default:
|
||||
LOGGER.info(Common.addTag("do not do any thing!"));
|
||||
}
|
||||
}
|
||||
|
||||
private int[] deprecatedModelInference() {
|
||||
int[] labels = new int[0];
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
|
@ -290,166 +576,27 @@ public class SyncFLJob {
|
|||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
}
|
||||
return labels;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains the latest model on the cloud.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus getModel() {
|
||||
Common.setSecureRandom(Common.getFastSecureRandom());
|
||||
int tag = 0;
|
||||
private FLClientStatus deprecatedGetModel() {
|
||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||
FLClientStatus status;
|
||||
try {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
|
||||
flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the " +
|
||||
"return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " +
|
||||
flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
|
||||
flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
"is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FLCommunication flCommunication = FLCommunication.getInstance();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
GetModel getModelBuf = GetModel.getInstance();
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
|
||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||
if (!Common.isSeverReady(message)) {
|
||||
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
|
||||
"again"));
|
||||
status = FLClientStatus.WAIT;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] get model request success"));
|
||||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
||||
status = getModelBuf.doResponse(responseDataBuf);
|
||||
LOGGER.info(Common.addTag("[getModel] success!"));
|
||||
} catch (Exception ex) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
SessionUtil.free(alTrainBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
}
|
||||
FLLiteClient flLiteClient = new FLLiteClient();
|
||||
status = flLiteClient.getModel();
|
||||
return status;
|
||||
}
|
||||
|
||||
private void waitSomeTime() {
|
||||
if (flParameter.getSleepTime() != 0) {
|
||||
Common.sleep(flParameter.getSleepTime());
|
||||
} else {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
}
|
||||
}
|
||||
|
||||
private void waitNextReqTime(String nextReqTime) {
|
||||
long waitTime = Common.getWaitTime(nextReqTime);
|
||||
Common.sleep(waitTime);
|
||||
}
|
||||
|
||||
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
|
||||
LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again"));
|
||||
waitNextReqTime(nextReqTime);
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
||||
}
|
||||
|
||||
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
|
||||
LOGGER.info(Common.addTag(tag + " failed"));
|
||||
LOGGER.info(Common.addTag("=========================================the total response of " +
|
||||
iteration + ": " + curStatus + "========================================="));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
String trainDataset = args[0];
|
||||
String vocabFile = args[1];
|
||||
String idsFile = args[2];
|
||||
String testDataset = args[3];
|
||||
String flName = args[4];
|
||||
String trainModelPath = args[5];
|
||||
String inferModelPath = args[6];
|
||||
boolean useSSL = Boolean.parseBoolean(args[7]);
|
||||
String domainName = args[8];
|
||||
boolean useElb = Boolean.parseBoolean(args[9]);
|
||||
int serverNum = Integer.parseInt(args[10]);
|
||||
String certPath = args[11];
|
||||
String task = args[12];
|
||||
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
|
||||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
if (task.equals("train")) {
|
||||
if (useSSL) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setTrainDataset(trainDataset);
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setTestDataset(testDataset);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
if (ALBERT.equals(flName)) {
|
||||
flParameter.setVocabFile(vocabFile);
|
||||
flParameter.setIdsFile(idsFile);
|
||||
}
|
||||
syncFLJob.flJobRun();
|
||||
} else if (task.equals("inference")) {
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTestDataset(testDataset);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
if (ALBERT.equals(flName)) {
|
||||
flParameter.setVocabFile(vocabFile);
|
||||
flParameter.setIdsFile(idsFile);
|
||||
}
|
||||
syncFLJob.modelInference();
|
||||
} else if (task.equals("getModel")) {
|
||||
if (useSSL) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
syncFLJob.getModel();
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("do not do any thing!"));
|
||||
if (args[4] == null || args[4].isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("the parameter of <args[4]> is null, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (Common.checkFLName(args[4])) {
|
||||
deprecatedTask(args);
|
||||
} else {
|
||||
task(args);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,14 +16,16 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.Client;
|
||||
import com.mindspore.flclient.model.ClientManager;
|
||||
import com.mindspore.flclient.model.CommonUtils;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.Status;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestUpdateModel;
|
||||
|
@ -33,9 +35,13 @@ import mindspore.schema.ResponseUpdateModel;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Define the serialization method, handle the response message returned from server for updateModel request.
|
||||
*
|
||||
|
@ -52,6 +58,7 @@ public class UpdateModel {
|
|||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private FLClientStatus status;
|
||||
private int retCode = ResponseCode.RequestError;
|
||||
|
||||
private UpdateModel() {
|
||||
}
|
||||
|
@ -78,6 +85,10 @@ public class UpdateModel {
|
|||
return status;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a flatBuffer builder of RequestUpdateModel.
|
||||
*
|
||||
|
@ -88,7 +99,16 @@ public class UpdateModel {
|
|||
*/
|
||||
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
|
||||
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
|
||||
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID())
|
||||
boolean isPkiVerify = flParameter.isPkiVerify();
|
||||
if (isPkiVerify) {
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration);
|
||||
return builder.flName(flParameter.getFlName()).time(dateTime).id(localFLParameter.getFlID())
|
||||
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).signData(signature).build();
|
||||
}
|
||||
return builder.flName(flParameter.getFlName()).time("null").id(localFLParameter.getFlID())
|
||||
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
|
||||
}
|
||||
|
||||
|
@ -99,8 +119,9 @@ public class UpdateModel {
|
|||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseUpdateModel response) {
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
|
||||
retCode = response.retcode();
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========the response message of updateModel is================"));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========retCode: " + retCode));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
|
||||
switch (response.retcode()) {
|
||||
|
@ -120,6 +141,65 @@ public class UpdateModel {
|
|||
}
|
||||
}
|
||||
|
||||
private Map<String, float[]> getFeatureMap() {
|
||||
Client client = ClientManager.getClient(flParameter.getFlName());
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
List<MSTensor> features = client.getFeatures();
|
||||
Map<String, float[]> trainedMap = CommonUtils.convertTensorToFeatures(features);
|
||||
LOGGER.info(Common.addTag("[updateModel] ===========free session============="));
|
||||
Common.freeSession();
|
||||
if (trainedMap.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return trainedMap is empty in <CommonUtils" +
|
||||
".convertTensorToFeatures>"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return trainedMap;
|
||||
}
|
||||
|
||||
private Map<String, float[]> deprecatedGetFeatureMap() {
|
||||
status = Common.initSession(flParameter.getTrainModelPath());
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||
flParameter.getFlName()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||
".convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||
flParameter.getFlName()));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||
".convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
Common.freeSession();
|
||||
return map;
|
||||
}
|
||||
|
||||
class RequestUpdateModelBuilder {
|
||||
private RequestUpdateModel requestUM;
|
||||
private FlatBufferBuilder builder;
|
||||
|
@ -127,6 +207,7 @@ public class UpdateModel {
|
|||
private int nameOffset = 0;
|
||||
private int idOffset = 0;
|
||||
private int timestampOffset = 0;
|
||||
private int signDataOffset = 0;
|
||||
private int iteration = 0;
|
||||
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
||||
|
||||
|
@ -153,12 +234,22 @@ public class UpdateModel {
|
|||
/**
|
||||
* Serialize the element timestamp in RequestUpdateModel.
|
||||
*
|
||||
* @param setTime current timestamp when the request is sent.
|
||||
* @return the RequestUpdateModelBuilder object.
|
||||
*/
|
||||
private RequestUpdateModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timestampOffset = builder.createString(String.valueOf(time));
|
||||
private RequestUpdateModelBuilder time(String setTime) {
|
||||
if (setTime == null || setTime.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the parameter of <setTime> is null or empty, please " +
|
||||
"check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (setTime.equals("null")) {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timestampOffset = builder.createString(String.valueOf(time));
|
||||
} else {
|
||||
this.timestampOffset = builder.createString(setTime);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -189,10 +280,16 @@ public class UpdateModel {
|
|||
}
|
||||
|
||||
private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
|
||||
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
|
||||
ArrayList<String> updateFeatureName = secureProtocol.getUpdateFeatureName();
|
||||
Map<String, float[]> trainedMap = new HashMap<String, float[]>();
|
||||
if (Common.checkFLName(flParameter.getFlName())) {
|
||||
trainedMap = deprecatedGetFeatureMap();
|
||||
} else {
|
||||
trainedMap = getFeatureMap();
|
||||
}
|
||||
switch (encryptLevel) {
|
||||
case PW_ENCRYPT:
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
|
||||
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
|
||||
"null, please check");
|
||||
|
@ -202,10 +299,12 @@ public class UpdateModel {
|
|||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||
return this;
|
||||
case DP_ENCRYPT:
|
||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
|
||||
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
|
||||
"null, please check");
|
||||
retCode = ResponseCode.RequestError;
|
||||
status = FLClientStatus.FAILED;
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||
|
@ -213,36 +312,11 @@ public class UpdateModel {
|
|||
return this;
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||
flParameter.getFlName()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||
".convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||
flParameter.getFlName()));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||
".convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int featureSize = updateFeatureName.size();
|
||||
int[] fmOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
String key = updateFeatureName.get(i);
|
||||
float[] data = trainedMap.get(key);
|
||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
|
||||
"size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
|
@ -258,6 +332,22 @@ public class UpdateModel {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the element signature in RequestUpdateModel.
|
||||
*
|
||||
* @param signData the signature Data.
|
||||
* @return the RequestUpdateModelBuilder object.
|
||||
*/
|
||||
private RequestUpdateModelBuilder signData(byte[] signData) {
|
||||
if (signData == null || signData.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the parameter of <signData> is null or empty, please " +
|
||||
"check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.signDataOffset = RequestUpdateModel.createSignatureVector(builder, signData);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a flatBuffer builder of RequestUpdateModel.
|
||||
*
|
||||
|
@ -270,6 +360,7 @@ public class UpdateModel {
|
|||
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
||||
RequestUpdateModel.addIteration(builder, this.iteration);
|
||||
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
||||
RequestUpdateModel.addSignature(builder, this.signDataOffset);
|
||||
int root = RequestUpdateModel.endRequestUpdateModel(builder);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
|
|
|
@ -21,7 +21,6 @@ import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
|
|||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.security.InvalidAlgorithmParameterException;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
|
|
|
@ -0,0 +1,411 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLParameter;
|
||||
|
||||
import org.bouncycastle.asn1.ASN1OctetString;
|
||||
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier;
|
||||
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier;
|
||||
import org.bouncycastle.util.encoders.Hex;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.KeyStore;
|
||||
import java.security.KeyStoreException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.NoSuchProviderException;
|
||||
import java.security.PublicKey;
|
||||
import java.security.SignatureException;
|
||||
import java.security.UnrecoverableEntryException;
|
||||
import java.security.cert.CRLException;
|
||||
import java.security.cert.Certificate;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509CRL;
|
||||
import java.security.cert.X509CRLEntry;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.Base64;
|
||||
import java.util.Date;
|
||||
import java.util.Set;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Certificate verification class
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class CertVerify {
|
||||
private static final Logger LOGGER = Logger.getLogger(CertVerify.class.toString());
|
||||
|
||||
/**
|
||||
* Verify the legitimacy of certificate chain
|
||||
*
|
||||
* @param clientID clientID of this client
|
||||
* @param x509Certificates certificate chain
|
||||
* @return verification result
|
||||
*/
|
||||
public static boolean verifyCertificateChain(String clientID, X509Certificate[] x509Certificates) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (x509Certificates == null || x509Certificates.length < 2) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] the parameter x509Certificates is null or the length is not " +
|
||||
"valid: < 2, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (verifyChain(clientID, x509Certificates) && verifyCommonName(clientID, x509Certificates)
|
||||
&& verifyCrl(clientID, x509Certificates) && verifyValidDate(x509Certificates) &&
|
||||
verifyKeyIdentifier(clientID, x509Certificates)) {
|
||||
LOGGER.info(Common.addTag("[CertVerify] verifyCertificateChain success!"));
|
||||
return true;
|
||||
}
|
||||
LOGGER.severe(Common.addTag("[CertVerify] verifyCertificateChain failed!"));
|
||||
return false;
|
||||
}
|
||||
|
||||
private static boolean verifyCommonName(String clientID, X509Certificate[] x509Certificate) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (x509Certificate == null || x509Certificate.length < 2) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] x509Certificate chains is null or the length is not valid: < 2," +
|
||||
" please check!"));
|
||||
return false;
|
||||
}
|
||||
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
|
||||
if (certificateChains == null || certificateChains.length < 4) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 4, " +
|
||||
"please check!"));
|
||||
return false;
|
||||
}
|
||||
X509Certificate localEquipCACert = certificateChains[2];
|
||||
// get subjectDN of local root equipment CA certificate
|
||||
String localEquipCAName = localEquipCACert.getSubjectDN().getName();
|
||||
// get issueDN of client's equipment certificate
|
||||
X509Certificate remoteEquipCert = x509Certificate[1];
|
||||
String equipIssueName = remoteEquipCert.getIssuerDN().getName();
|
||||
return localEquipCAName.equals(equipIssueName);
|
||||
}
|
||||
|
||||
// check whether the former certificate owner is the publisher of next one.
|
||||
private static boolean verifyChain(String clientID, X509Certificate[] x509Certificates) {
|
||||
if (x509Certificates == null || x509Certificates.length < 2) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 2, " +
|
||||
"please check!"));
|
||||
return false;
|
||||
}
|
||||
|
||||
// check remote equipment certificate
|
||||
try {
|
||||
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
|
||||
if (certificateChains == null || certificateChains.length < 3) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 3, " +
|
||||
"please check!"));
|
||||
return false;
|
||||
}
|
||||
X509Certificate localEquipCA = certificateChains[2];
|
||||
PublicKey publicKey = localEquipCA.getPublicKey();
|
||||
x509Certificates[1].verify(publicKey);
|
||||
} catch (NoSuchProviderException | CertificateException | NoSuchAlgorithmException |
|
||||
InvalidKeyException | SignatureException e) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
|
||||
return false;
|
||||
}
|
||||
|
||||
// check remote service certificate
|
||||
X509Certificate remoteEquipCert = x509Certificates[1];
|
||||
X509Certificate remoteServiceCert = x509Certificates[0];
|
||||
|
||||
try {
|
||||
remoteEquipCert.checkValidity();
|
||||
remoteServiceCert.checkValidity();
|
||||
} catch (java.security.cert.CertificateExpiredException |
|
||||
java.security.cert.CertificateNotYetValidException e) {
|
||||
e.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
PublicKey publicKey = remoteEquipCert.getPublicKey();
|
||||
remoteServiceCert.verify(publicKey);
|
||||
} catch (CertificateException | NoSuchAlgorithmException | InvalidKeyException |
|
||||
NoSuchProviderException | SignatureException e) {
|
||||
LOGGER.severe(Common.addTag("verifyChain failed!"));
|
||||
LOGGER.severe(Common.addTag("[verifyChain] catch Exception: " + e.getMessage()));
|
||||
return false;
|
||||
}
|
||||
LOGGER.severe(Common.addTag("verifyChain success!"));
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* get certificate chain according to clientID
|
||||
*
|
||||
* @param clientID clientID of this client
|
||||
* @return certificate chain
|
||||
*/
|
||||
public static X509Certificate[] getX509CertificateChain(String clientID) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
|
||||
return null;
|
||||
}
|
||||
X509Certificate[] x509Certificates = null;
|
||||
try {
|
||||
Certificate[] certificates = null;
|
||||
KeyStore keyStore = KeyStore.getInstance(CipherConsts.KEYSTORE_TYPE);
|
||||
keyStore.load(null);
|
||||
KeyStore.Entry entry = keyStore.getEntry(clientID, null);
|
||||
if (entry == null || !(entry instanceof KeyStore.PrivateKeyEntry)) {
|
||||
return null;
|
||||
}
|
||||
certificates = ((KeyStore.PrivateKeyEntry) entry).getCertificateChain();
|
||||
if (certificates == null) {
|
||||
return null;
|
||||
}
|
||||
x509Certificates = (X509Certificate[]) certificates;
|
||||
} catch (IOException | NoSuchAlgorithmException | UnrecoverableEntryException | KeyStoreException |
|
||||
CertificateException e) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
|
||||
}
|
||||
return x509Certificates;
|
||||
}
|
||||
|
||||
/**
|
||||
* transform pem format to X509 format
|
||||
*
|
||||
* @param pemCerts pem format certificate
|
||||
* @return X509 format certificates
|
||||
*/
|
||||
public static X509Certificate[] transformPemArrayToX509Array(String[] pemCerts) {
|
||||
if (pemCerts == null || pemCerts.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] pemCerts is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int nSize = pemCerts.length;
|
||||
X509Certificate[] x509Certificates = new X509Certificate[nSize];
|
||||
for (int i = 0; i < nSize; ++i) {
|
||||
x509Certificates[i] = transformPemToX509(pemCerts[i]);
|
||||
}
|
||||
return x509Certificates;
|
||||
}
|
||||
|
||||
private static X509Certificate transformPemToX509(String pemCert) {
|
||||
X509Certificate x509Certificate = null;
|
||||
CertificateFactory cf;
|
||||
try {
|
||||
if (pemCert != null && !pemCert.trim().isEmpty()) {
|
||||
byte[] certificateData = Base64.getDecoder().decode(pemCert);
|
||||
cf = CertificateFactory.getInstance("X509");
|
||||
x509Certificate = (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(certificateData));
|
||||
}
|
||||
} catch (CertificateException e) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
|
||||
return null;
|
||||
}
|
||||
return x509Certificate;
|
||||
}
|
||||
|
||||
private static boolean verifyCrl(String clientID, X509Certificate[] x509Certificates) {
|
||||
if (x509Certificates == null || x509Certificates.length < 2) {
|
||||
LOGGER.severe(Common.addTag("[verifyCrl] the number of certificate in x509Certificates is less than 2, " +
|
||||
"please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
X509Certificate equipCert = x509Certificates[1];
|
||||
if (equipCert == null) {
|
||||
LOGGER.severe(Common.addTag("[verifyCrl] equipCert is null, please check it!"));
|
||||
return false;
|
||||
}
|
||||
String equipCertSerialNum = equipCert.getSerialNumber().toString();
|
||||
if (verifySingleCrl(clientID, equipCertSerialNum, flParameter.getEquipCrlPath())) {
|
||||
LOGGER.info(Common.addTag("[verifyCrl] verify crl certificate success!"));
|
||||
return true;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[verifyCrl] verify crl certificate failed!"));
|
||||
return false;
|
||||
}
|
||||
|
||||
private static boolean verifySingleCrl(String clientID, String caSerialNumber, String crlPath) {
|
||||
if (caSerialNumber == null || caSerialNumber.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] caSerialNumber is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
// crlPath does not exist
|
||||
if (crlPath.equals("null")) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] crlPath is null, please set crlPath with setEquipCrlPath " +
|
||||
"method!"));
|
||||
return false;
|
||||
}
|
||||
boolean notInFlag = true;
|
||||
try {
|
||||
X509CRL crl = (X509CRL) readCrl(crlPath);
|
||||
if (crl != null) {
|
||||
// check CRL cert with local equipment CA publicKey
|
||||
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
|
||||
if (certificateChains == null || certificateChains.length < 3) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not" +
|
||||
" valid: < 3, please check!"));
|
||||
return false;
|
||||
}
|
||||
X509Certificate localEquipCA = certificateChains[2];
|
||||
PublicKey publicKey = localEquipCA.getPublicKey();
|
||||
crl.verify(publicKey);
|
||||
|
||||
// check whether remote equipmentCert in CRL
|
||||
Set<?> set = crl.getRevokedCertificates();
|
||||
if (set == null) {
|
||||
LOGGER.info(Common.addTag("[verifySingleCrl] verifyCrl Revoked Cert list is null"));
|
||||
return true;
|
||||
}
|
||||
for (Object obj : set) {
|
||||
X509CRLEntry crlEntity = (X509CRLEntry) obj;
|
||||
if (crlEntity.getSerialNumber().toString().equals(caSerialNumber)) {
|
||||
LOGGER.info(Common.addTag("[verifySingleCrl] Find same SerialNumber during the crl!"));
|
||||
notInFlag = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (java.security.cert.CRLException | java.security.NoSuchAlgorithmException |
|
||||
java.security.InvalidKeyException | java.security.NoSuchProviderException |
|
||||
java.security.SignatureException e) {
|
||||
LOGGER.severe(Common.addTag("[verifySingleCrl] judgeCAInCRL error: " + e.getMessage()));
|
||||
notInFlag = false;
|
||||
}
|
||||
return notInFlag;
|
||||
}
|
||||
|
||||
private static Object readCrl(String assetName) {
|
||||
if (assetName == null || assetName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[readCrl] the parameter of <assetName> is null or empty, please check!"));
|
||||
return null;
|
||||
}
|
||||
InputStream inputStream = null;
|
||||
try {
|
||||
inputStream = new FileInputStream(assetName);
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[readCrl] catch Exception of read inputStream in readCert: " +
|
||||
e.getMessage()));
|
||||
return null;
|
||||
}
|
||||
Object crlCert = null;
|
||||
try {
|
||||
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
||||
crlCert = cf.generateCRL(inputStream);
|
||||
} catch (CertificateException | CRLException e) {
|
||||
LOGGER.severe(Common.addTag("[readCrl] catch Exception of creating CertificateFactory in readCert: "
|
||||
+ e.getMessage()));
|
||||
} finally {
|
||||
try {
|
||||
inputStream.close();
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[readCrl] catch Exception of close inputStream: "
|
||||
+ e.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
return crlCert;
|
||||
}
|
||||
|
||||
private static boolean verifyValidDate(X509Certificate[] x509Certificates) {
|
||||
if (x509Certificates == null) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] x509Certificates is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
Date date = new Date();
|
||||
try {
|
||||
int nSize = x509Certificates.length;
|
||||
for (int i = 0; i < nSize; ++i) {
|
||||
x509Certificates[i].checkValidity(date);
|
||||
}
|
||||
} catch (java.security.cert.CertificateExpiredException |
|
||||
java.security.cert.CertificateNotYetValidException e) {
|
||||
LOGGER.severe(Common.addTag("[verifyValidDate] catch Exception: " + e.getMessage()));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private static boolean verifyKeyIdentifier(String clientID, X509Certificate[] x509Certificates) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (x509Certificates == null || x509Certificates.length < 2) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] x509Certificate chains is null or the length is not valid: < 2," +
|
||||
" please check!"));
|
||||
return false;
|
||||
}
|
||||
|
||||
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
|
||||
if (certificateChains == null || certificateChains.length < 3) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 3, " +
|
||||
"please check!"));
|
||||
return false;
|
||||
}
|
||||
|
||||
X509Certificate localEquipCACert = certificateChains[2];
|
||||
String subjectIdentifier = "null";
|
||||
|
||||
// get subjectKeyIdentifier of local root equipment CA certificate
|
||||
try {
|
||||
String subjectIdentifierOid = "2.5.29.14";
|
||||
byte[] subjectExtendData = localEquipCACert.getExtensionValue(subjectIdentifierOid);
|
||||
ASN1OctetString asn1OctetString = ASN1OctetString.getInstance(subjectExtendData);
|
||||
byte[] tmpData = asn1OctetString.getOctets();
|
||||
SubjectKeyIdentifier subjectKeyIdentifier = SubjectKeyIdentifier.getInstance(tmpData);
|
||||
byte[] octKeyIdentifier = subjectKeyIdentifier.getKeyIdentifier();
|
||||
subjectIdentifier = new String(Hex.encode(octKeyIdentifier));
|
||||
} catch (ExceptionInInitializerError e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
// get authorityKeyIdentifier of client's equipment certificate
|
||||
X509Certificate remoteEquipCert = x509Certificates[1];
|
||||
String authorityIdentifier = "null";
|
||||
try {
|
||||
if (remoteEquipCert == null) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] remoteEquipCert is null, please check it!"));
|
||||
return false;
|
||||
}
|
||||
String authorityIdentifierOid = "2.5.29.35";
|
||||
byte[] authExtendData = remoteEquipCert.getExtensionValue(authorityIdentifierOid);
|
||||
ASN1OctetString asn1OctetString = ASN1OctetString.getInstance(authExtendData);
|
||||
byte[] tmpData = asn1OctetString.getOctets();
|
||||
AuthorityKeyIdentifier authorityKeyIdentifier = AuthorityKeyIdentifier.getInstance(tmpData);
|
||||
byte[] octKeyIdentifier = authorityKeyIdentifier.getKeyIdentifier();
|
||||
authorityIdentifier = new String(Hex.encode(octKeyIdentifier));
|
||||
} catch (ExceptionInInitializerError e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
if (authorityIdentifier.equals("null") || subjectIdentifier.equals("null")) {
|
||||
LOGGER.severe(Common.addTag("[CertVerify] authorityKeyIdentifier or subjectKeyIdentifier is null, check " +
|
||||
"failed!"));
|
||||
return false;
|
||||
} else {
|
||||
return authorityIdentifier.equals(subjectIdentifier);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
/**
|
||||
* Consts used for certification signature
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class CipherConsts {
|
||||
/**
|
||||
* provider name
|
||||
*/
|
||||
public static final String PROVIDER_NAME = "HwUniversalKeyStoreProvider";
|
||||
|
||||
/**
|
||||
* keyStore type
|
||||
*/
|
||||
public static final String KEYSTORE_TYPE = "HwKeyStore";
|
||||
|
||||
/**
|
||||
* sign algorithm
|
||||
*/
|
||||
public static final String SIGN_ALGORITHM = "SHA256withRSA/PSS";
|
||||
}
|
|
@ -21,6 +21,7 @@ import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
|||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.CipherClient;
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.FLCommunication;
|
||||
|
@ -91,12 +92,28 @@ public class ClientListReq {
|
|||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = builder.createString(dateTime);
|
||||
|
||||
GetClientList.startGetClientList(builder);
|
||||
GetClientList.addFlId(builder, id);
|
||||
GetClientList.addIteration(builder, iteration);
|
||||
GetClientList.addTimestamp(builder, time);
|
||||
int clientListRoot = GetClientList.endGetClientList(builder);
|
||||
int clientListRoot;
|
||||
byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration);
|
||||
if (signature == null) {
|
||||
LOGGER.severe(Common.addTag("[getClientList] get signature is null!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (signature.length > 0) {
|
||||
int signed = GetClientList.createSignatureVector(builder, signature);
|
||||
GetClientList.startGetClientList(builder);
|
||||
GetClientList.addFlId(builder, id);
|
||||
GetClientList.addIteration(builder, iteration);
|
||||
GetClientList.addTimestamp(builder, time);
|
||||
GetClientList.addSignature(builder, signed);
|
||||
clientListRoot = GetClientList.endGetClientList(builder);
|
||||
} else {
|
||||
GetClientList.startGetClientList(builder);
|
||||
GetClientList.addFlId(builder, id);
|
||||
GetClientList.addIteration(builder, iteration);
|
||||
GetClientList.addTimestamp(builder, time);
|
||||
GetClientList.addSignature(builder, 0);
|
||||
clientListRoot = GetClientList.endGetClientList(builder);
|
||||
}
|
||||
builder.finish(clientListRoot);
|
||||
byte[] msg = builder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
|
||||
|
@ -107,6 +124,7 @@ public class ClientListReq {
|
|||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
|
|
|
@ -20,6 +20,7 @@ import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
|||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.CipherClient;
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.FLCommunication;
|
||||
|
@ -68,7 +69,8 @@ public class ReconstructSecretReq {
|
|||
*/
|
||||
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList,
|
||||
List<String> u3ClientList, int iteration) {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||
int desFlId = builder.createString(localFLParameter.getFlID());
|
||||
Date date = new Date();
|
||||
|
@ -106,12 +108,28 @@ public class ReconstructSecretReq {
|
|||
}
|
||||
int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder,
|
||||
decryptShareList);
|
||||
SendReconstructSecret.startSendReconstructSecret(builder);
|
||||
SendReconstructSecret.addFlId(builder, desFlId);
|
||||
SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets);
|
||||
SendReconstructSecret.addIteration(builder, iteration);
|
||||
SendReconstructSecret.addTimestamp(builder, time);
|
||||
int reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder);
|
||||
|
||||
int reconstructSecretRoot;
|
||||
byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration);
|
||||
if (signature.length > 0) {
|
||||
int signed = SendReconstructSecret.createSignatureVector(builder, signature);
|
||||
SendReconstructSecret.startSendReconstructSecret(builder);
|
||||
SendReconstructSecret.addFlId(builder, desFlId);
|
||||
SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets);
|
||||
SendReconstructSecret.addIteration(builder, iteration);
|
||||
SendReconstructSecret.addTimestamp(builder, time);
|
||||
SendReconstructSecret.addSignature(builder, signed);
|
||||
reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder);
|
||||
} else {
|
||||
SendReconstructSecret.startSendReconstructSecret(builder);
|
||||
SendReconstructSecret.addFlId(builder, desFlId);
|
||||
SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets);
|
||||
SendReconstructSecret.addIteration(builder, iteration);
|
||||
SendReconstructSecret.addTimestamp(builder, time);
|
||||
SendReconstructSecret.addSignature(builder, 0);
|
||||
reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder);
|
||||
}
|
||||
|
||||
builder.finish(reconstructSecretRoot);
|
||||
byte[] msg = builder.sizedByteArray();
|
||||
try {
|
||||
|
@ -121,6 +139,7 @@ public class ReconstructSecretReq {
|
|||
"time and request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
retCode = ResponseCode.OutOfTime;
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.Key;
|
||||
import java.security.KeyStore;
|
||||
import java.security.KeyStoreException;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.NoSuchProviderException;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.PublicKey;
|
||||
import java.security.Signature;
|
||||
import java.security.SignatureException;
|
||||
import java.security.UnrecoverableKeyException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* class used for sign data and verify data
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class SignAndVerify {
|
||||
private static final Logger LOGGER = Logger.getLogger(SignAndVerify.class.toString());
|
||||
|
||||
/**
|
||||
* sign data
|
||||
*
|
||||
* @param clientID ID of this client
|
||||
* @param data data need to be signed
|
||||
* @return signed data
|
||||
*/
|
||||
public static byte[] signData(String clientID, byte[] data) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter clientID is null or empty, please check!"));
|
||||
return null;
|
||||
}
|
||||
if (data == null || data.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter data is null or empty, please check!"));
|
||||
return null;
|
||||
}
|
||||
byte[] signData = null;
|
||||
try {
|
||||
KeyStore ks = KeyStore.getInstance(CipherConsts.KEYSTORE_TYPE);
|
||||
ks.load(null);
|
||||
Key privateKey = ks.getKey(clientID, null);
|
||||
if (privateKey == null) {
|
||||
LOGGER.info("private key is null");
|
||||
return null;
|
||||
}
|
||||
Signature signature = Signature.getInstance(CipherConsts.SIGN_ALGORITHM, CipherConsts.PROVIDER_NAME);
|
||||
signature.initSign((PrivateKey) privateKey);
|
||||
signature.update(data);
|
||||
signData = signature.sign();
|
||||
} catch (KeyStoreException | CertificateException | NoSuchAlgorithmException | IOException
|
||||
| UnrecoverableKeyException | NoSuchProviderException | InvalidKeyException
|
||||
| SignatureException e) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] catch Exception: " + e.getMessage()));
|
||||
}
|
||||
return signData;
|
||||
}
|
||||
|
||||
/**
|
||||
* verify signature with certifications
|
||||
*
|
||||
* @param clientID ID of this client
|
||||
* @param x509Certificates certificates
|
||||
* @param data original data
|
||||
* @param signed signed data
|
||||
* @return verify result
|
||||
*/
|
||||
public static boolean verifySignatureByCert(String clientID, X509Certificate[] x509Certificates, byte[] data,
|
||||
byte[] signed) {
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter clientID is null or empty, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (x509Certificates == null || x509Certificates.length < 1) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter x509Certificates is null or the length is not " +
|
||||
"valid: < 1, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (data == null || data.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter data is null or empty, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (signed == null || signed.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter signed is null or empty, please check!"));
|
||||
return false;
|
||||
}
|
||||
if (!CertVerify.verifyCertificateChain(clientID, x509Certificates)) {
|
||||
LOGGER.info(Common.addTag("Verify chain failed!"));
|
||||
return false;
|
||||
}
|
||||
LOGGER.info(Common.addTag("Verify chain success!"));
|
||||
|
||||
boolean isValid;
|
||||
try {
|
||||
if (x509Certificates[0].getPublicKey() == null) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] get public key failed!"));
|
||||
return false;
|
||||
}
|
||||
PublicKey publicKey = x509Certificates[0].getPublicKey(); // get public key
|
||||
Signature signature = Signature.getInstance(CipherConsts.SIGN_ALGORITHM);
|
||||
signature.initVerify(publicKey); // set public key
|
||||
signature.update(data); // set data
|
||||
isValid = signature.verify(signed); // verify the consistence between signature and data
|
||||
} catch (NoSuchAlgorithmException | SignatureException | InvalidKeyException e) {
|
||||
LOGGER.severe(Common.addTag("[SignAndVerify] catch Exception: " + e.getMessage()));
|
||||
return false;
|
||||
}
|
||||
return isValid;
|
||||
}
|
||||
|
||||
/**
|
||||
* get hash result of bytes
|
||||
*
|
||||
* @param bytes inputs
|
||||
* @return hash value of bytes
|
||||
*/
|
||||
public static byte[] getSHA256(byte[] bytes) {
|
||||
MessageDigest messageDigest;
|
||||
byte[] hash = new byte[0];
|
||||
try {
|
||||
messageDigest = MessageDigest.getInstance("SHA-256");
|
||||
hash = messageDigest.digest(bytes);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch NoSuchAlgorithmException: " + e.getMessage()));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
}
|
|
@ -40,7 +40,8 @@ public class ClientPublicKey {
|
|||
*/
|
||||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before use"));
|
||||
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before " +
|
||||
"using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
|
|
|
@ -38,7 +38,7 @@ public class EncryptShare {
|
|||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
|
||||
"use"));
|
||||
"using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
|
|
|
@ -20,7 +20,6 @@ package com.mindspore.flclient.cipher.struct;
|
|||
* class used define new array type
|
||||
*
|
||||
* @param <T> an array
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class NewArray<T> {
|
||||
|
|
|
@ -38,7 +38,7 @@ public class ShareSecret {
|
|||
*/
|
||||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before use"));
|
||||
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before using"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.pki;
|
||||
|
||||
import java.security.cert.Certificate;
|
||||
|
||||
/**
|
||||
* PkiBean entity
|
||||
*
|
||||
* @since 2021-08-25
|
||||
*/
|
||||
public class PkiBean {
|
||||
private byte[] signData;
|
||||
|
||||
private Certificate[] certificates;
|
||||
|
||||
public PkiBean(byte[] signData, Certificate[] certificates) {
|
||||
this.signData = signData;
|
||||
this.certificates = certificates;
|
||||
}
|
||||
|
||||
public byte[] getSignData() {
|
||||
return signData;
|
||||
}
|
||||
|
||||
public void setSignData(byte[] signData) {
|
||||
this.signData = signData;
|
||||
}
|
||||
|
||||
public Certificate[] getCertificates() {
|
||||
return certificates;
|
||||
}
|
||||
|
||||
public void setCertificates(Certificate[] certificates) {
|
||||
this.certificates = certificates;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.pki;
|
||||
|
||||
/**
|
||||
* Pki consts
|
||||
*
|
||||
* @since 2021-08-25
|
||||
*/
|
||||
public class PkiConsts {
|
||||
/**
|
||||
* PROVIDER_NAME
|
||||
*/
|
||||
public static final String PROVIDER_NAME = "HwUniversalKeyStoreProvider";
|
||||
|
||||
/**
|
||||
* KEYSTORE_TYPE
|
||||
*/
|
||||
public static final String KEYSTORE_TYPE = "HwKeyStore";
|
||||
|
||||
/**
|
||||
* ALGORITHM
|
||||
*/
|
||||
public static final String ALGORITHM = "SHA256withRSA/PSS";
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.pki;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.LocalFLParameter;
|
||||
|
||||
import org.bouncycastle.util.io.pem.PemObject;
|
||||
import org.bouncycastle.util.io.pem.PemWriter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.StringWriter;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.Key;
|
||||
import java.security.KeyStore;
|
||||
import java.security.KeyStoreException;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.NoSuchProviderException;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.Signature;
|
||||
import java.security.SignatureException;
|
||||
import java.security.UnrecoverableEntryException;
|
||||
import java.security.UnrecoverableKeyException;
|
||||
import java.security.cert.Certificate;
|
||||
import java.security.cert.CertificateEncodingException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.Locale;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Pki Util
|
||||
*
|
||||
* @since 2021-08-25
|
||||
*/
|
||||
public class PkiUtil {
|
||||
private static final Logger LOGGER = Logger.getLogger(PkiUtil.class.toString());
|
||||
|
||||
/**
|
||||
* generate PkiBean
|
||||
*
|
||||
* @param clientID String
|
||||
* @param time long
|
||||
* @return PkiBean
|
||||
*/
|
||||
public static PkiBean genPkiBean(String clientID, long time) {
|
||||
String sourceData = LocalFLParameter.getInstance().getFlID() + " " + time;
|
||||
byte[] signDataBytes = PkiUtil.signData(clientID,
|
||||
sourceData.getBytes(StandardCharsets.UTF_8));
|
||||
|
||||
Certificate[] certificates = PkiUtil.getCertificateChain(clientID);
|
||||
|
||||
return new PkiBean(signDataBytes, certificates);
|
||||
}
|
||||
|
||||
/**
|
||||
* get str of SHA256
|
||||
*
|
||||
* @param str String
|
||||
* @return hash
|
||||
*/
|
||||
public static byte[] getSHA256Str(String str) {
|
||||
MessageDigest messageDigest;
|
||||
byte[] hash = new byte[0];
|
||||
try {
|
||||
messageDigest = MessageDigest.getInstance("SHA-256");
|
||||
hash = messageDigest.digest(str.getBytes(StandardCharsets.UTF_8));
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch NoSuchAlgorithmException: " + e.getMessage()));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
/**
|
||||
* get Pem format
|
||||
*
|
||||
* @param certificate Certificate
|
||||
* @return String result
|
||||
* @throws IOException e
|
||||
*/
|
||||
public static String getPemFormat(Certificate certificate) throws IOException {
|
||||
StringWriter writer = new StringWriter();
|
||||
PemWriter pemWriter = new PemWriter(writer);
|
||||
try {
|
||||
pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded()));
|
||||
} catch (IOException | CertificateEncodingException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch IOException or CertificateEncodingException in getPermFormat: "
|
||||
+ e.getMessage()));
|
||||
} finally {
|
||||
pemWriter.flush();
|
||||
pemWriter.close();
|
||||
}
|
||||
|
||||
return writer.toString();
|
||||
}
|
||||
|
||||
private static byte[] signData(String clientID, byte[] data) {
|
||||
byte[] signData = null;
|
||||
try {
|
||||
KeyStore ks = KeyStore.getInstance(PkiConsts.KEYSTORE_TYPE);
|
||||
ks.load(null);
|
||||
Key privateKey = ks.getKey(clientID, null);
|
||||
if (privateKey == null) {
|
||||
return new byte[0];
|
||||
}
|
||||
Signature signature = Signature.getInstance(PkiConsts.ALGORITHM, PkiConsts.PROVIDER_NAME);
|
||||
if (privateKey instanceof PrivateKey) {
|
||||
signature.initSign((PrivateKey) privateKey);
|
||||
}
|
||||
signature.update(data);
|
||||
signData = signature.sign();
|
||||
} catch (KeyStoreException | CertificateException | NoSuchAlgorithmException | IOException
|
||||
| UnrecoverableKeyException | NoSuchProviderException | InvalidKeyException
|
||||
| SignatureException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
|
||||
}
|
||||
return signData;
|
||||
}
|
||||
|
||||
/**
|
||||
* get certificate chain
|
||||
*
|
||||
* @param clientID String
|
||||
* @return Certificate[]
|
||||
*/
|
||||
public static Certificate[] getCertificateChain(String clientID) {
|
||||
Certificate[] certificates = null;
|
||||
try {
|
||||
KeyStore keyStore = KeyStore.getInstance(PkiConsts.KEYSTORE_TYPE);
|
||||
keyStore.load(null);
|
||||
|
||||
KeyStore.Entry entry = keyStore.getEntry(clientID, null);
|
||||
if (entry == null) {
|
||||
return new Certificate[0];
|
||||
}
|
||||
|
||||
if (!(entry instanceof KeyStore.PrivateKeyEntry)) {
|
||||
return new Certificate[0];
|
||||
}
|
||||
|
||||
certificates = ((KeyStore.PrivateKeyEntry) entry).getCertificateChain();
|
||||
} catch (IOException | CertificateException | NoSuchAlgorithmException
|
||||
| UnrecoverableEntryException | KeyStoreException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
|
||||
}
|
||||
|
||||
return certificates;
|
||||
}
|
||||
|
||||
/**
|
||||
* to hex format
|
||||
*
|
||||
* @param data byte[]
|
||||
* @return String
|
||||
*/
|
||||
public static String toHexFormat(byte[] data) {
|
||||
if (data == null || data.length == 0) {
|
||||
return "";
|
||||
}
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (byte byteData : data) {
|
||||
sb.append(String.format(Locale.ROOT, "%02x", byteData));
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* gen equip cert hash
|
||||
*
|
||||
* @param clientID String
|
||||
* @return String
|
||||
*/
|
||||
public static String genEquipCertHash(String clientID) {
|
||||
String equipCert;
|
||||
byte[] equipCertBytesHash = null;
|
||||
try {
|
||||
Certificate[] certificates = getCertificateChain(clientID);
|
||||
if (certificates == null || certificates.length < 2) {
|
||||
return "";
|
||||
}
|
||||
equipCert = readPemFormat(certificates[1]);
|
||||
equipCertBytesHash = getSHA256Str(equipCert);
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
|
||||
}
|
||||
|
||||
return toHexFormat(equipCertBytesHash);
|
||||
}
|
||||
|
||||
/**
|
||||
* generate hash from cer
|
||||
*
|
||||
* @param certificateGradeOne X509Certificate
|
||||
* @return String
|
||||
*/
|
||||
public static String genHashFromCer(X509Certificate certificateGradeOne) {
|
||||
String equipCert = null;
|
||||
byte[] equipCertBytesHash = null;
|
||||
try {
|
||||
equipCert = readPemFormat(certificateGradeOne);
|
||||
equipCertBytesHash = getSHA256Str(equipCert);
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
|
||||
}
|
||||
if (equipCertBytesHash == null) {
|
||||
return "";
|
||||
}
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (byte byteData : equipCertBytesHash) {
|
||||
sb.append(String.format(Locale.ROOT, "%02x", byteData));
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* read pem format
|
||||
*
|
||||
* @param certificate Certificate
|
||||
* @return String result
|
||||
* @throws IOException e
|
||||
*/
|
||||
public static String readPemFormat(Certificate certificate) throws IOException {
|
||||
StringWriter writer = new StringWriter();
|
||||
PemWriter pemWriter = new PemWriter(writer);
|
||||
if (certificate == null) {
|
||||
LOGGER.severe(Common.addTag("[PkiUtil] the input parameter certificate is null, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
try {
|
||||
pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded()));
|
||||
} catch (IOException | CertificateEncodingException e) {
|
||||
LOGGER.severe(
|
||||
Common.addTag("[PkiUtil] catch IOException or CertificateEncodingException in getPermFormat: "
|
||||
+ e.getMessage()));
|
||||
} finally {
|
||||
pemWriter.flush();
|
||||
pemWriter.close();
|
||||
}
|
||||
|
||||
return writer.toString();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue