!27594 add code of flclient for device decoupling

Merge pull request !27594 from zhoushan33/flclient1213_decoupling
This commit is contained in:
i-robot 2021-12-16 11:01:06 +00:00 committed by Gitee
commit 6331c01405
27 changed files with 3683 additions and 917 deletions

View File

@ -0,0 +1,12 @@
package com.mindspore.flclient;
/**
* The cpu bind mod.
*
* @since 2021-06-30
*/
public enum BindMode {
NOT_BINDING_CORE,
BIND_LARGE_CORE,
BIND_MIDDLE_CORE
}

View File

@ -25,33 +25,44 @@ import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.cipher.AESEncrypt;
import com.mindspore.flclient.cipher.BaseUtil;
import com.mindspore.flclient.cipher.CertVerify;
import com.mindspore.flclient.cipher.ClientListReq;
import com.mindspore.flclient.cipher.KEYAgreement;
import com.mindspore.flclient.cipher.Masking;
import com.mindspore.flclient.cipher.ReconstructSecretReq;
import com.mindspore.flclient.cipher.ShareSecrets;
import com.mindspore.flclient.cipher.SignAndVerify;
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray;
import com.mindspore.flclient.cipher.struct.ShareSecret;
import com.mindspore.flclient.pki.PkiUtil;
import mindspore.schema.ClientShare;
import mindspore.schema.GetExchangeKeys;
import mindspore.schema.GetShareSecrets;
import mindspore.schema.RequestAllClientListSign;
import mindspore.schema.RequestExchangeKeys;
import mindspore.schema.RequestShareSecrets;
import mindspore.schema.ResponseClientListSign;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseExchangeKeys;
import mindspore.schema.ResponseShareSecrets;
import mindspore.schema.ReturnAllClientListSign;
import mindspore.schema.ReturnExchangeKeys;
import mindspore.schema.ReturnShareSecrets;
import mindspore.schema.SendClientListSign;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
@ -93,6 +104,7 @@ public class CipherClient {
private ClientListReq clientListReq = new ClientListReq();
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
private int retCode;
private Map<String, X509Certificate[]> certificateList = new HashMap<String, X509Certificate[]>();
/**
* Construct function of cipherClient
@ -423,12 +435,6 @@ public class CipherClient {
" please check!"));
return FLClientStatus.FAILED;
}
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
byte[] cPK = cKey.get(0);
byte[] sPK = sKey.get(0);
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
byte[] indIv = new byte[I_VEC_LEN];
byte[] pwIv = new byte[I_VEC_LEN];
byte[] thisPwSalt = new byte[SALT_SIZE];
@ -439,15 +445,24 @@ public class CipherClient {
this.individualIv = indIv;
this.pwIVec = pwIv;
this.pwSalt = thisPwSalt;
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv);
int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv);
int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt);
// for pkiVerify mode
int exchangeKeysRoot;
byte[] cPK = cKey.get(0);
byte[] sPK = sKey.get(0);
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
int id = fbBuilder.createString(localFLParameter.getFlID());
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
String clientID = flParameter.getClientID();
// start build
RequestExchangeKeys.startRequestExchangeKeys(fbBuilder);
@ -459,7 +474,24 @@ public class CipherClient {
RequestExchangeKeys.addIndIv(fbBuilder, indIvFbs);
RequestExchangeKeys.addPwIv(fbBuilder, pwIvFbs);
RequestExchangeKeys.addPwSalt(fbBuilder, pwSaltFbs);
int exchangeKeysRoot = RequestExchangeKeys.endRequestExchangeKeys(fbBuilder);
if (flParameter.isPkiVerify()) {
// waiting for certificates to take effect
int waitTakeEffectTime = 5000;
Common.sleep(waitTakeEffectTime);
int nSize = 2; // exchange equipment certificate and service equipment
String[] pemCertificateChains = transformX509ArrayToPemArray(CertVerify.getX509CertificateChain(clientID));
int[] pemList = new int[nSize];
for (int i = 0; i < nSize; i++) {
pemList[i] = fbBuilder.createString(pemCertificateChains[i]);
}
int certificatesInt = RequestExchangeKeys.createCertificateChainVector(fbBuilder, pemList);
byte[] signature = signPkAndTime(clientID, cPK, sPK, dateTime, iteration);
int signed = RequestExchangeKeys.createSignatureVector(fbBuilder, signature);
RequestExchangeKeys.addSignature(fbBuilder, signed);
RequestExchangeKeys.addCertificateChain(fbBuilder, certificatesInt);
}
exchangeKeysRoot = RequestExchangeKeys.endRequestExchangeKeys(fbBuilder);
fbBuilder.finish(exchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
@ -472,6 +504,7 @@ public class CipherClient {
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
@ -518,12 +551,30 @@ public class CipherClient {
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
// start build
GetExchangeKeys.startGetExchangeKeys(fbBuilder);
GetExchangeKeys.addFlId(fbBuilder, id);
GetExchangeKeys.addIteration(fbBuilder, iteration);
GetExchangeKeys.addTimestamp(fbBuilder, time);
int getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder);
int getExchangeKeysRoot;
byte[] signature = signTimeAndIter(dateTime, iteration);
if (signature == null) {
LOGGER.severe(Common.addTag("[getExchangeKeys] get signature is null!"));
return FLClientStatus.FAILED;
}
if (signature.length > 0) {
int signed = GetExchangeKeys.createSignatureVector(fbBuilder, signature);
// start build
GetExchangeKeys.startGetExchangeKeys(fbBuilder);
GetExchangeKeys.addFlId(fbBuilder, id);
GetExchangeKeys.addIteration(fbBuilder, iteration);
GetExchangeKeys.addTimestamp(fbBuilder, time);
GetExchangeKeys.addSignature(fbBuilder, signed);
getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder);
} else {
// start build
GetExchangeKeys.startGetExchangeKeys(fbBuilder);
GetExchangeKeys.addFlId(fbBuilder, id);
GetExchangeKeys.addIteration(fbBuilder, iteration);
GetExchangeKeys.addTimestamp(fbBuilder, time);
getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder);
}
fbBuilder.finish(getExchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
@ -535,6 +586,7 @@ public class CipherClient {
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
@ -558,22 +610,44 @@ public class CipherClient {
clientPublicKeyList.clear();
u1ClientList.clear();
int length = bufData.remotePublickeysLength();
for (int i = 0; i < length; i++) {
ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer();
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
byte[] bufCpkList = byteBufferToList(bufCpk, sizeCpk);
byte[] bufSpkList = byteBufferToList(bufSpk, sizeSpk);
// copy bufCpkList and bufSpkList
byte[] cPkByte = bufCpkList.clone();
byte[] sPkByte = bufSpkList.clone();
// check signature
boolean isPkiVerify = flParameter.isPkiVerify();
if (isPkiVerify) {
FLClientStatus checkResult = checkSignature(bufData, i, cPkByte, sPkByte);
if (checkResult == FLClientStatus.FAILED) {
return FLClientStatus.FAILED;
}
}
ClientPublicKey publicKey = new ClientPublicKey();
String srcFlId = bufData.remotePublickeys(i).flId();
publicKey.setFlID(srcFlId);
ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer();
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer();
int sizePwIv = bufData.remotePublickeys(i).pwIvLength();
ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer();
int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength();
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv));
publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt));
NewArray<byte[]> bufCpkArray = new NewArray<>();
bufCpkArray.setSize(sizeCpk);
bufCpkArray.setArray(bufCpkList);
NewArray<byte[]> bufSpkArray = new NewArray<>();
bufSpkArray.setSize(sizeSpk);
bufSpkArray.setArray(bufSpkList);
publicKey.setCPK(bufCpkArray);
publicKey.setSPK(bufSpkArray);
clientPublicKeyList.put(srcFlId, publicKey);
u1ClientList.add(srcFlId);
}
@ -598,6 +672,86 @@ public class CipherClient {
}
}
private FLClientStatus checkSignature(ReturnExchangeKeys bufData, int dataIndex, byte[] cPkByte, byte[] sPkByte) {
ByteBuffer signature = bufData.remotePublickeys(dataIndex).signatureAsByteBuffer();
byte[] sigByte = new byte[signature.remaining()];
signature.get(sigByte);
int certifyNum = bufData.remotePublickeys(dataIndex).certificateChainLength();
String[] pemCerts = new String[certifyNum];
for (int certIndex = 0; certIndex < certifyNum; certIndex++) {
pemCerts[certIndex] = bufData.remotePublickeys(dataIndex).certificateChain(certIndex);
}
X509Certificate[] x509Certificates = CertVerify.transformPemArrayToX509Array(pemCerts);
if (x509Certificates.length < 2) {
LOGGER.severe(Common.addTag("the length of x509Certificates is not valid, should be >= 2"));
return FLClientStatus.FAILED;
}
String certificateHash = PkiUtil.genHashFromCer(x509Certificates[1]);
LOGGER.info(Common.addTag("Get certificate hash success!"));
// check srcId
String srcFlId = bufData.remotePublickeys(dataIndex).flId();
if (certificateHash.equals(srcFlId)) {
LOGGER.info(Common.addTag("Check flID success and source flID is:" + srcFlId));
} else {
LOGGER.severe(Common.addTag("Check flID failed!" + "source flID: " + srcFlId + "Hash ID from certificate:" +
" " + certificateHash.equals(srcFlId)));
return FLClientStatus.FAILED;
}
certificateList.put(srcFlId, x509Certificates);
String timestamp = bufData.remotePublickeys(dataIndex).timestamp();
String clientID = flParameter.getClientID();
if (!verifySignature(clientID, x509Certificates, sigByte, cPkByte, sPkByte, timestamp, iteration)) {
LOGGER.info(Common.addTag("[PairWiseMask] FlID: " + srcFlId +
", signature authentication failed"));
return FLClientStatus.FAILED;
} else {
LOGGER.info(Common.addTag("[PairWiseMask] Verify signature success!"));
}
// check iteration and timestamp
int remoteIter = bufData.iteration();
FLClientStatus iterTimeCheck = checkIterAndTimestamp(remoteIter, timestamp);
if (iterTimeCheck == FLClientStatus.FAILED) {
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus checkIterAndTimestamp(int remoteIter, String timestamp) {
if (remoteIter != iteration) {
LOGGER.severe(Common.addTag("[PairWiseMask] iteration check failed. Remote iteration of client: " + "is "
+ remoteIter + ", which is not consistent with current iteration:" + iteration));
return FLClientStatus.FAILED;
}
Date date = new Date();
long currentTimeStamp = date.getTime();
if (timestamp == null) {
LOGGER.severe(Common.addTag("[PairWiseMask] Received timeStamp is null,please check it!"));
return FLClientStatus.FAILED;
}
long remoteTimeStamp = Long.parseLong(timestamp);
long validIterInterval = flParameter.getValidInterval();
if (currentTimeStamp - remoteTimeStamp > validIterInterval || currentTimeStamp < remoteTimeStamp) {
LOGGER.severe(Common.addTag("[PairWiseMask] timeStamp check failed! The difference between" +
" remote timestamp and current timestamp is beyond valid iteration interval!"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
private byte[] byteBufferToList(ByteBuffer buf, int size) {
byte[] array = new byte[size];
for (int i = 0; i < size; i++) {
byte word = buf.get();
array[i] = word;
}
return array;
}
private FLClientStatus requestShareSecrets() {
FLClientStatus status = genIndividualSecret();
if (status == FLClientStatus.FAILED) {
@ -642,25 +796,44 @@ public class CipherClient {
}
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
// start build
RequestShareSecrets.startRequestShareSecrets(fbBuilder);
RequestShareSecrets.addFlId(fbBuilder, id);
RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs);
RequestShareSecrets.addIteration(fbBuilder, iteration);
RequestShareSecrets.addTimestamp(fbBuilder, time);
int requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder);
int requestShareSecretsRoot;
byte[] signature = signTimeAndIter(dateTime, iteration);
if (signature == null) {
LOGGER.severe(Common.addTag("[PairWiseMask] get signature is null!"));
return FLClientStatus.FAILED;
}
if (signature.length > 0) {
int signed = RequestShareSecrets.createSignatureVector(fbBuilder, signature);
// start build
RequestShareSecrets.startRequestShareSecrets(fbBuilder);
RequestShareSecrets.addFlId(fbBuilder, id);
RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs);
RequestShareSecrets.addIteration(fbBuilder, iteration);
RequestShareSecrets.addTimestamp(fbBuilder, time);
RequestShareSecrets.addSignature(fbBuilder, signed);
requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder);
} else {
// start build
RequestShareSecrets.startRequestShareSecrets(fbBuilder);
RequestShareSecrets.addFlId(fbBuilder, id);
RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs);
RequestShareSecrets.addIteration(fbBuilder, iteration);
RequestShareSecrets.addTimestamp(fbBuilder, time);
requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder);
}
fbBuilder.finish(requestShareSecretsRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
" " +
"and request again"));
" and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
@ -708,11 +881,27 @@ public class CipherClient {
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
GetShareSecrets.startGetShareSecrets(fbBuilder);
GetShareSecrets.addFlId(fbBuilder, id);
GetShareSecrets.addIteration(fbBuilder, iteration);
GetShareSecrets.addTimestamp(fbBuilder, time);
int getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder);
int getShareSecrets;
byte[] signature = signTimeAndIter(dateTime, iteration);
if (signature == null) {
LOGGER.severe(Common.addTag("[getShareSecrets] get signature is null!"));
return FLClientStatus.FAILED;
}
if (signature.length > 0) {
int signed = GetShareSecrets.createSignatureVector(fbBuilder, signature);
GetShareSecrets.startGetShareSecrets(fbBuilder);
GetShareSecrets.addFlId(fbBuilder, id);
GetShareSecrets.addIteration(fbBuilder, iteration);
GetShareSecrets.addTimestamp(fbBuilder, time);
GetShareSecrets.addSignature(fbBuilder, signed);
getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder);
} else {
GetShareSecrets.startGetShareSecrets(fbBuilder);
GetShareSecrets.addFlId(fbBuilder, id);
GetShareSecrets.addIteration(fbBuilder, iteration);
GetShareSecrets.addTimestamp(fbBuilder, time);
getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder);
}
fbBuilder.finish(getShareSecrets);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
@ -724,6 +913,7 @@ public class CipherClient {
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
@ -864,6 +1054,19 @@ public class CipherClient {
}
retCode = clientListReq.getRetCode();
// clientListCheck
if (flParameter.isPkiVerify()) {
LOGGER.info(Common.addTag("[PairWiseMask] The mode is pkiVerify mode, start clientList check ..."));
curStatus = clientListCheck();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = clientListCheck();
}
if (curStatus != FLClientStatus.SUCCESS) {
return curStatus;
}
}
// SendReconstructSecret
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
while (curStatus == FLClientStatus.WAIT) {
@ -876,4 +1079,340 @@ public class CipherClient {
retCode = reconstructSecretReq.getRetCode();
return curStatus;
}
private byte[] signPkAndTime(String clientID, byte[] cPK, byte[] sPK, String time, int iterNum) {
// concatenate cPK, sPK and time
byte[] concatData = concatenateData(cPK, sPK, time, iterNum);
LOGGER.info("concatenate data success!");
// signature
return SignAndVerify.signData(clientID, concatData);
}
private static byte[] concatenateData(byte[] cPK, byte[] sPK, String time, int iterNum) {
// concatenate cPK, sPK and time
if (time == null) {
LOGGER.severe(Common.addTag("[concatenateData] input time is null, please check!"));
throw new IllegalArgumentException();
}
byte[] byteTime = time.getBytes(StandardCharsets.UTF_8);
String iterString = String.valueOf(iterNum);
byte[] byteIter = iterString.getBytes(StandardCharsets.UTF_8);
int concatLength = cPK.length + sPK.length + byteTime.length + byteIter.length;
byte[] concatData = new byte[concatLength];
int offset = 0;
System.arraycopy(cPK, 0, concatData, offset, cPK.length);
offset += cPK.length;
System.arraycopy(sPK, 0, concatData, offset, sPK.length);
offset += sPK.length;
System.arraycopy(byteTime, 0, concatData, offset, byteTime.length);
offset += byteTime.length;
System.arraycopy(byteIter, 0, concatData, offset, byteIter.length);
return concatData;
}
private static byte[] concatenateIterAndTime(String time, int iterNum) {
// concatenate cPK, sPK and time
byte[] byteTime = time.getBytes(StandardCharsets.UTF_8);
String iterString = String.valueOf(iterNum);
byte[] byteIter = iterString.getBytes(StandardCharsets.UTF_8);
int concatLength = byteTime.length + byteIter.length;
byte[] concatData = new byte[concatLength];
int offset = 0;
System.arraycopy(byteTime, 0, concatData, offset, byteTime.length);
offset += byteTime.length;
System.arraycopy(byteIter, 0, concatData, offset, byteIter.length);
return concatData;
}
private static boolean verifySignature(String clientID, X509Certificate[] x509Certificates, byte[] signature,
byte[] cPK, byte[] sPK, String timestamp, int iteration) {
byte[] concatData = concatenateData(cPK, sPK, timestamp, iteration);
return SignAndVerify.verifySignatureByCert(clientID, x509Certificates, concatData, signature);
}
private FLClientStatus clientListCheck() {
LOGGER.info(Common.addTag("[PairWiseMask] ==================== ClientListCheck ======================"));
FLClientStatus curStatus;
// send signed clientList
curStatus = sendClientListSign();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = sendClientListSign();
}
if (curStatus != FLClientStatus.SUCCESS) {
return curStatus;
}
// get signed clientList
curStatus = getAllClientListSign();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = getAllClientListSign();
}
return curStatus;
}
private FLClientStatus sendClientListSign() {
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " +
localFLParameter.getFlID() + "=============="));
genDHKeyPairs();
List<String> clientList = u3ClientList;
int listSize = u3ClientList.size();
if (listSize == 0) {
LOGGER.severe("[Encrypt] u3List is empty, please check!");
return FLClientStatus.FAILED;
}
// send signature
byte[] clientListByte = transStringListToByte(clientList);
byte[] listHash = SignAndVerify.getSHA256(clientListByte);
String clientID = flParameter.getClientID();
byte[] signature = SignAndVerify.signData(clientID, listHash);
if (signature == null) {
LOGGER.severe(Common.addTag("[sendClientListSign] the returned signature is null"));
return FLClientStatus.FAILED;
}
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int signed = RequestExchangeKeys.createSignatureVector(fbBuilder, signature);
int sendClientListRoot;
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
byte[] reqSign = signTimeAndIter(dateTime, iteration);
String flID = localFLParameter.getFlID();
int id = fbBuilder.createString(flID);
int time = fbBuilder.createString(dateTime);
if (signature.length > 0) {
int reqSigned = SendClientListSign.createSignatureVector(fbBuilder, reqSign);
sendClientListRoot = SendClientListSign.createSendClientListSign(fbBuilder, id, iteration, time, signed,
reqSigned);
} else {
sendClientListRoot = SendClientListSign.createSendClientListSign(fbBuilder, id, iteration, time, signed, 0);
}
fbBuilder.finish(sendClientListRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/pushListSign", msg);
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[sendClientListSign] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseClientListSign responseClientListSign =
ResponseClientListSign.getRootAsResponseClientListSign(buffer);
return judgeRequestClientList(responseClientListSign);
} catch (IOException e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
private FLClientStatus judgeRequestClientList(ResponseClientListSign bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestClientListSign**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] RequestClientListSign success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestClientListSign out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestClientListSign"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in RequestClientListSign" +
" is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
private FLClientStatus getAllClientListSign() {
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
int requestAllClientListSign;
byte[] signature = signTimeAndIter(dateTime, iteration);
if (signature.length > 0) {
int signed = RequestAllClientListSign.createSignatureVector(fbBuilder, signature);
requestAllClientListSign = RequestAllClientListSign.createRequestAllClientListSign(fbBuilder, id,
iteration, time, signed);
} else {
requestAllClientListSign = RequestAllClientListSign.createRequestAllClientListSign(fbBuilder, id,
iteration, time, 0);
}
fbBuilder.finish(requestAllClientListSign);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/getListSign", msg);
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getAllClientListSign] the server is not ready now, need wait some time " +
"and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnAllClientListSign returnAllClientList =
ReturnAllClientListSign.getRootAsReturnAllClientListSign(buffer);
return judgeAllClientList(returnAllClientList);
} catch (IOException e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
private FLClientStatus judgeAllClientList(ReturnAllClientListSign bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetAllClientsList**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetAllClientList success"));
int length = bufData.clientListSignLength();
String clientID = flParameter.getClientID();
String localFlID = localFLParameter.getFlID();
byte[] localClientList = transStringListToByte(u3ClientList);
byte[] localListHash = SignAndVerify.getSHA256(localClientList);
for (int i = 0; i < length; i++) {
// verify signature
ByteBuffer signature = bufData.clientListSign(i).signatureAsByteBuffer();
byte[] sigByte = new byte[signature.remaining()];
signature.get(sigByte);
if (bufData.clientListSign(i).flId() == null) {
LOGGER.severe(Common.addTag("[PairWiseMask] get flID failed!"));
return FLClientStatus.FAILED;
}
String srcFlId = bufData.clientListSign(i).flId();
X509Certificate[] remoteCertificates = certificateList.get(srcFlId);
if (localFlID.equals(srcFlId)) {
continue;
} // Do not verify itself
if (!SignAndVerify.verifySignatureByCert(clientID, remoteCertificates, localListHash, sigByte)) {
LOGGER.info(Common.addTag("[PairWiseMask] FlID: " + srcFlId +
", signature authentication failed"));
return FLClientStatus.FAILED;
}
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetAllClientsList again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetAllClientsList out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetAllClientsList"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnAllClientList " +
"is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
private byte[] transStringListToByte(List<String> stringList) {
int byteNum = 0;
for (String value : stringList) {
byte[] stringByte = value.getBytes(StandardCharsets.UTF_8);
byteNum += stringByte.length;
}
byte[] concatData = new byte[byteNum];
int offset = 0;
for (String str : stringList) {
byte[] stringByte = str.getBytes(StandardCharsets.UTF_8);
System.arraycopy(stringByte, 0, concatData, offset, stringByte.length);
offset += stringByte.length;
}
return concatData;
}
/**
* Add signature on timestamp and iteration
*
* @param dateTime the timestamp of data
* @param iteration iteration number
* @return signed time and iteration
*/
public static byte[] signTimeAndIter(String dateTime, int iteration) {
// signature
FLParameter flParameter = FLParameter.getInstance();
String clientID = flParameter.getClientID();
boolean isPkiVerify = flParameter.isPkiVerify();
byte[] signature = new byte[0];
if (isPkiVerify) {
LOGGER.info(Common.addTag("ClientID is:" + clientID));
byte[] concatData = concatenateIterAndTime(dateTime, iteration);
signature = SignAndVerify.signData(clientID, concatData);
}
return signature;
}
private static String transformX509ToPem(X509Certificate x509Certificate) {
if (x509Certificate == null) {
LOGGER.severe(Common.addTag("[CertVerify] x509Certificate is null, please check!"));
return null;
}
String pemCert;
try {
byte[] derCert = x509Certificate.getEncoded();
pemCert = new String(Base64.getEncoder().encode(derCert));
} catch (CertificateEncodingException e) {
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
return null;
}
return pemCert;
}
private static String[] transformX509ArrayToPemArray(X509Certificate[] x509Certificates) {
if (x509Certificates == null || x509Certificates.length == 0) {
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or empty, please check!"));
throw new IllegalArgumentException();
}
int nSize = x509Certificates.length;
String[] pemCerts = new String[nSize];
for (int i = 0; i < nSize; ++i) {
String pemCert = transformX509ToPem(x509Certificates[i]);
pemCerts[i] = pemCert;
}
return pemCerts;
}
}

View File

@ -16,6 +16,14 @@
package com.mindspore.flclient;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseCode;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.engines.AESEngine;
import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder;
@ -33,6 +41,9 @@ import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Define basic global methods used in federated learning task.
*
@ -44,6 +55,9 @@ public class Common {
*/
public static final String LOG_TITLE = "<FLClient> ";
public static final String LOG_DEPRECATED = "This method will be deprecated in the next version, it is " +
"recommended to " +
"use the latest method according to the use cases in the official website tutorial";
/**
* The list of trust flName.
*/
@ -63,8 +77,16 @@ public class Common {
* The tag when server is not ready.
*/
public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished.";
/**
* use to stop job.
*/
private static final Object STOP_OBJECT = new Object();
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static List<String> envTrustList = new ArrayList<>(Arrays.asList("x86", "android"));
private static SecureRandom secureRandom;
private static int iteration;
private static boolean isHttps;
/**
* Generate the URL for device-sever interaction
@ -107,6 +129,7 @@ public class Common {
throw new IllegalArgumentException();
}
String tag = domainName.split("//")[0] + "//";
setIsHttps(domainName.split("//")[0].split(":")[0]);
Random rand = new Random();
int randomNum = rand.nextInt(100000) % serverNum + port;
url = tag + ip + ":" + String.valueOf(randomNum);
@ -167,6 +190,17 @@ public class Common {
return (FL_NAME_TRUST_LIST.contains(flName));
}
/**
* Check if the deploy environment set by user is in the trust list.
*
* @param env the deploy environment for federated learning task set by user.
* @return boolean value, true indicates the deploy environment set by user is valid, false indicates the deploy
* environment set by user is not valid.
*/
public static boolean checkEnv(String env) {
return (envTrustList.contains(env));
}
/**
* Check whether the sslProtocol set by user is in the trust list.
*
@ -184,11 +218,23 @@ public class Common {
* @param millis the waiting time (ms).
*/
public static void sleep(long millis) {
try {
Thread.sleep(millis); // 1000 milliseconds is one second.
} catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
Thread.currentThread().interrupt();
if (millis > 0) {
try {
synchronized (STOP_OBJECT) {
STOP_OBJECT.wait(millis); // 1000 milliseconds is one second.
}
} catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
}
}
}
/**
* Use to stop the wait method of a Object.
*/
public static void notifyObject() {
synchronized (STOP_OBJECT) {
STOP_OBJECT.notify();
}
}
@ -261,10 +307,20 @@ public class Common {
String messageStr = new String(message);
if (messageStr.contains(SAFE_MOD)) {
LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again"));
if (messageStr.split(":").length == 2) {
iteration = Integer.parseInt(messageStr.split(":")[1]);
} else {
LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
}
return false;
} else if (messageStr.contains(JOB_NOT_AVAILABLE)) {
LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " +
"again"));
if (messageStr.split(":").length == 2) {
iteration = Integer.parseInt(messageStr.split(":")[1]);
} else {
LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration."));
}
return false;
} else {
return true;
@ -307,7 +363,7 @@ public class Common {
* Check whether the path set by user exists.
*
* @param path the path set by user.
* @return boolean value, true indicates the path is exist, false indicates the path is not exist
* @return boolean value, true indicates the path is exist, false indicates the path does not exist
*/
public static boolean checkPath(String path) {
if (path == null) {
@ -323,7 +379,7 @@ public class Common {
LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
File file = new File(paths[i]);
if (!file.exists()) {
LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check"));
LOGGER.severe(Common.addTag("[checkPath] the path does not exist, please check"));
return false;
}
}
@ -415,4 +471,131 @@ public class Common {
throw new IllegalArgumentException();
}
}
/**
* Record the current iteration when server is not ready.
*
* @return int value, the current iteration when server is not ready.
*/
public static int getIteration() {
return iteration;
}
/**
* Determine whether to conduct https communication according to the domain name set by user.
*
* @return boolean value, true means conducting https communication, false means conducting http communication.
*/
public static boolean isHttps() {
return isHttps;
}
public static void setIsHttps(String tag) {
if ("https".equals(tag)) {
LOGGER.info(Common.addTag("conducting https communication"));
Common.isHttps = true;
} else if ("http".equals(tag)) {
LOGGER.info(Common.addTag("conducting http communication"));
Common.isHttps = false;
} else {
LOGGER.info(Common.addTag("The domain header set by the user is incorrect, please check"));
throw new IllegalArgumentException();
}
}
/**
* Initialization session.
*
* @return the status code in client.
*/
public static FLClientStatus initSession(String modelPath) {
FLParameter flParameter = FLParameter.getInstance();
LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
if (Common.checkFLName(flParameter.getFlName())) {
return deprecatedInitSession();
}
Status tag;
LOGGER.info(Common.addTag("==========Loading model, " + modelPath + " Create " +
" Session============="));
Client client = ClientManager.getClient(flParameter.getFlName());
tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig());
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
/**
* Free session.
*/
protected static void freeSession() {
FLParameter flParameter = FLParameter.getInstance();
if (Common.checkFLName(flParameter.getFlName())) {
deprecatedFreeSession();
} else {
LOGGER.info(Common.addTag("===========free session============="));
Client client = ClientManager.getClient(flParameter.getFlName());
client.free();
}
}
/**
* Initialization session.
*
* @return the status code in client.
*/
private static FLClientStatus deprecatedInitSession() {
FLParameter flParameter = FLParameter.getInstance();
int tag = 0;
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
"Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
"-1"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
/**
* Free session.
*/
private static void deprecatedFreeSession() {
FLParameter flParameter = FLParameter.getInstance();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
SessionUtil.free(alTrainBert.getTrainSession());
if (!flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("===========free inference session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
SessionUtil.free(alInferBert.getTrainSession());
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
SessionUtil.free(trainLenet.getTrainSession());
}
}
}

View File

@ -17,6 +17,8 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.TIME_OUT;
import static com.mindspore.flclient.LocalFLParameter.ANDROID;
import static com.mindspore.flclient.LocalFLParameter.X86;
import okhttp3.Call;
import okhttp3.Callback;
@ -49,12 +51,16 @@ import javax.net.ssl.X509TrustManager;
*/
public class FLCommunication implements IFLCommunication {
private static int timeOut;
private static boolean ifCertificateVerify = false;
private static String sslProtocol;
private static String env;
private static SSLSocketFactory sslSocketFactory;
private static X509TrustManager x509TrustManager;
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
private static volatile FLCommunication communication;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private OkHttpClient client;
private FLCommunication() {
@ -63,7 +69,14 @@ public class FLCommunication implements IFLCommunication {
} else {
timeOut = TIME_OUT;
}
ifCertificateVerify = flParameter.isUseSSL();
sslProtocol = flParameter.getSslProtocol();
if (!Common.checkFLName(flParameter.getFlName())) {
env = flParameter.getDeployEnv();
if (ANDROID.equals(env)) {
sslSocketFactory = flParameter.getSslSocketFactory();
x509TrustManager = flParameter.getX509TrustManager();
}
}
client = getOkHttpClient();
}
@ -85,33 +98,29 @@ public class FLCommunication implements IFLCommunication {
}
};
final TrustManager[] trustAllCerts = new TrustManager[]{trustManager};
try {
LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut));
OkHttpClient.Builder builder = new OkHttpClient.Builder();
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
if (ifCertificateVerify) {
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
SSLSocketFactoryTools.getInstance().getmTrustManager());
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
} else {
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustAllCerts, Common.getSecureRandom());
final SSLSocketFactory sslFactory = sslContext.getSocketFactory();
builder.sslSocketFactory(sslFactory, trustManager);
LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut));
OkHttpClient.Builder builder = new OkHttpClient.Builder();
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
if (Common.isHttps()) {
if (ANDROID.equals(env)) {
builder.sslSocketFactory(sslSocketFactory, x509TrustManager);
builder.hostnameVerifier(new HostnameVerifier() {
@Override
public boolean verify(String arg0, SSLSession arg1) {
return true;
}
});
} else {
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
SSLSocketFactoryTools.getInstance().getmTrustManager());
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
}
return builder.build();
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage()));
throw new IllegalArgumentException(ex);
} else {
LOGGER.info(Common.addTag("conducting http communication, do not need SSLSocketFactoryTools"));
}
return builder.build();
}
/**

View File

@ -33,6 +33,7 @@ public class FLJobResultCallback implements IFLJobResultCallback {
* @param iterationSeq Iteration number
* @param resultCode Status Code
*/
@Override
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " +
iterationSeq + " resultCode: " + resultCode));
@ -45,6 +46,7 @@ public class FLJobResultCallback implements IFLJobResultCallback {
* @param iterationCount total Iteration numbers
* @param resultCode Status Code
*/
@Override
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " +
iterationCount + " resultCode: " + resultCode));

View File

@ -22,8 +22,16 @@ import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.flclient.pki.PkiBean;
import com.mindspore.flclient.pki.PkiUtil;
import com.mindspore.lite.MSTensor;
import mindspore.schema.CipherPublicParams;
import mindspore.schema.FLPlan;
@ -36,6 +44,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;
@ -54,7 +63,7 @@ public class FLLiteClient {
private double dpNormClipAdapt = 0.05d;
private FLCommunication flCommunication;
private FLClientStatus status;
private int retCode;
private int retCode = ResponseCode.RequestError;
private int iterations = 1;
private int epochs = 1;
private int batchSize = 16;
@ -68,12 +77,15 @@ public class FLLiteClient {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private SecureProtocol secureProtocol = new SecureProtocol();
private String nextRequestTime;
private Client client;
private Map<String, float[]> oldFeatureMap;
/**
* Defining a constructor of teh class FLLiteClient.
*/
public FLLiteClient() {
flCommunication = FLCommunication.getInstance();
client = ClientManager.getClient(flParameter.getFlName());
}
private int setGlobalParameters(ResponseFLJob flJob) {
@ -87,22 +99,15 @@ public class FLLiteClient {
batchSize = flPlan.miniBatch();
String serverMod = flPlan.serverMode();
localFLParameter.setServerMod(serverMod);
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <serverMod> from server: " + serverMod));
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AlTrainBert: " + batchSize));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
alTrainBert.setBatchSize(batchSize);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
TrainLenet trainLenet = TrainLenet.getInstance();
trainLenet.setBatchSize(batchSize);
if (Common.checkFLName(flParameter.getFlName())) {
deprecatedSetBatchSize(batchSize);
} else {
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
return -1;
LOGGER.info(Common.addTag("[startFLJob] not set <batchSize> for client: " + batchSize));
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <serverMod> from server: " + serverMod));
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <iterations> from server: " + iterations));
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <epochs> from server: " + epochs));
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <batchSize> from server: " + batchSize));
CipherPublicParams cipherPublicParams = flPlan.cipher();
if (cipherPublicParams == null) {
LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null"));
@ -232,7 +237,12 @@ public class FLLiteClient {
StartFLJob startFLJob = StartFLJob.getInstance();
Date date = new Date();
long time = date.getTime();
byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time);
PkiBean pkiBean = null;
if (flParameter.isPkiVerify()) {
pkiBean = PkiUtil.genPkiBean(flParameter.getClientID(), time);
}
byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time, pkiBean);
try {
long start = Common.startTime("single startFLJob");
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
@ -243,6 +253,7 @@ public class FLLiteClient {
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return status;
}
LOGGER.info(Common.addTag("[startFLJob] the response message length: " + message.length));
@ -250,12 +261,9 @@ public class FLLiteClient {
ByteBuffer buffer = ByteBuffer.wrap(message);
ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer);
status = judgeStartFLJob(startFLJob, responseDataBuf);
retCode = responseDataBuf.retcode();
} catch (IOException e) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " +
e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
failed("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage(),
ResponseCode.RequestError);
}
return status;
}
@ -263,12 +271,13 @@ public class FLLiteClient {
private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
iteration = responseDataBuf.iteration();
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
retCode = startFLJob.getRetCode();
status = response;
switch (response) {
case SUCCESS:
LOGGER.info(Common.addTag("[startFLJob] startFLJob success"));
featureSize = startFLJob.getFeatureSize();
secureProtocol.setEncryptFeatureName(startFLJob.getEncryptFeatureName());
secureProtocol.setUpdateFeatureName(startFLJob.getUpdateFeatureName());
LOGGER.info(Common.addTag("[startFLJob] ***the feature size get in ResponseFLJob***: " + featureSize));
int tag = setGlobalParameters(responseDataBuf);
if (tag == -1) {
@ -297,38 +306,37 @@ public class FLLiteClient {
return status;
}
private FLClientStatus trainLoop() {
retCode = ResponseCode.SUCCEED;
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("[train] train in " + flParameter.getFlName()));
Status tag = client.trainModel(epochs);
if (!Status.SUCCESS.equals(tag)) {
failed("[train] unsolved error code in <client.trainModel>", ResponseCode.RequestError);
}
client.saveModel(flParameter.getTrainModelPath());
Common.freeSession();
return status;
}
/**
* Define the training process.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus localTrain() {
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration +
"===================================="));
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[train] train in albert"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
LOGGER.severe(Common.addTag("[train] unsolved error code in <alTrainBert.trainModel>"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[train] train in lenet"));
TrainLenet trainLenet = TrainLenet.getInstance();
int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
LOGGER.severe(Common.addTag("[train] unsolved error code in <trainLenet.trainModel>"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
if (Common.checkFLName(flParameter.getFlName())) {
status = deprecatedTrainLoop();
} else {
LOGGER.severe(Common.addTag("[train] the flName is not valid"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
status = trainLoop();
}
return status;
}
@ -357,6 +365,7 @@ public class FLLiteClient {
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return status;
}
LOGGER.info(Common.addTag("[updateModel] the response message length: " + message.length));
@ -364,16 +373,14 @@ public class FLLiteClient {
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer);
status = updateModelBuf.doResponse(responseDataBuf);
retCode = responseDataBuf.retcode();
retCode = updateModelBuf.getRetCode();
if (status == FLClientStatus.RESTART) {
nextRequestTime = responseDataBuf.nextReqTime();
}
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
} catch (IOException e) {
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " +
e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
failed("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage(),
ResponseCode.RequestError);
}
return status;
}
@ -396,6 +403,7 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.WAIT;
retCode = ResponseCode.SucNotReady;
return status;
}
LOGGER.info(Common.addTag("[getModel] the response message length: " + message.length));
@ -404,19 +412,35 @@ public class FLLiteClient {
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
status = getModelBuf.doResponse(responseDataBuf);
retCode = responseDataBuf.retcode();
retCode = getModelBuf.getRetCode();
if (status == FLClientStatus.RESTART) {
nextRequestTime = responseDataBuf.timestamp();
}
LOGGER.info(Common.addTag("[getModel] get response from server ok!"));
} catch (IOException e) {
LOGGER.severe(Common.addTag("[getModel] un sloved error code: catch IOException: " + e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
failed("[getModel] un sloved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError);
}
return status;
}
private Map<String, float[]> getFeatureMap() {
Map<String, float[]> featureMap = new HashMap<>();
if (Common.checkFLName(flParameter.getFlName())) {
featureMap = deprecatedGetFeatureMap();
return featureMap;
}
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
Common.freeSession();
retCode = ResponseCode.RequestError;
return new HashMap<>();
}
List<MSTensor> features = client.getFeatures();
featureMap = CommonUtils.convertTensorToFeatures(features);
Common.freeSession();
return featureMap;
}
/**
* Obtain the weight of the model before training.
*
@ -456,6 +480,58 @@ public class FLLiteClient {
return mapBeforeTrain;
}
private void getOldFeatureMap() {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
Map<String, float[]> featureMap = getFeatureMap();
oldFeatureMap = getOldMapCopy(featureMap);
}
}
public void updateDpNormClip() {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
Map<String, float[]> fedFeatureMap = getFeatureMap();
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
if (fedWeightUpdateNorm == -1) {
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
"-1, please check!"));
throw new IllegalArgumentException();
}
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
float newNormCLip = (float) getDpNormClipFactor() * fedWeightUpdateNorm;
if (iteration == 1) {
setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
} else {
if (newNormCLip < getDpNormClipAdapt()) {
setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
}
}
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + getDpNormClipAdapt()));
}
}
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
float updateL2Norm = 0f;
for (String key : originalData.keySet()) {
float[] data = originalData.get(key);
float[] dataAfterUpdate = newData.get(key);
for (int j = 0; j < data.length; j++) {
if (j >= dataAfterUpdate.length) {
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
"please check");
return -1;
}
float updateData = data[j] - dataAfterUpdate[j];
updateL2Norm += updateData * updateData;
}
}
updateL2Norm = (float) Math.sqrt(updateL2Norm);
return updateL2Norm;
}
/**
* Obtain pairwise mask and individual mask.
*
@ -477,16 +553,14 @@ public class FLLiteClient {
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
return curStatus;
case DP_ENCRYPT:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
// get the feature map before train
getOldFeatureMap();
if (oldFeatureMap.isEmpty()) {
LOGGER.severe(Common.addTag("[Encrypt] the return map in getOldFeatureMapis empty "));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
Map<String, float[]> copyMap = getOldMapCopy(map);
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, copyMap);
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
@ -537,85 +611,50 @@ public class FLLiteClient {
}
}
private FLClientStatus evaluateLoop() {
client.free();
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath()));
acc = client.evalModel();
} else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig());
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getTrainModelPath()));
acc = client.evalModel();
}
if (Float.isNaN(acc)) {
failed("[evaluate] unsolved error code in <evalModel>: the return acc is NAN", ResponseCode.RequestError);
return status;
}
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
return status;
}
private void failed(String log, int retCode) {
LOGGER.severe(Common.addTag(log));
status = FLClientStatus.FAILED;
this.retCode = retCode;
}
/**
* Evaluate model after getting model from server.
*
* @return the status code in client.
*/
public FLClientStatus evaluateModel() {
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("===================================evaluate model after getting model from " +
"server==================================="));
if (flParameter.getFlName().equals(ALBERT)) {
float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlInferBert alInferBert = AlInferBert.getInstance();
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile(), true);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the " +
"return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
acc = alInferBert.evalModel();
} else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile());
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the " +
"return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
acc = alTrainBert.evalModel();
}
if (Float.isNaN(acc)) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
" idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
if (flParameter.getTestDataset().split(",").length < 2) {
LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " +
"format of <data.bin,label.bin> "));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
flParameter.getTestDataset().split(",")[1]);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return " +
"dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
float acc = trainLenet.evalModel();
if (Float.isNaN(acc)) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc" +
" is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
flParameter.getTestDataset().split(",")[1]));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
if (Common.checkFLName(flParameter.getFlName())) {
status = deprecatedEvaluateLoop();
} else {
status = evaluateLoop();
}
return status;
}
@ -623,10 +662,43 @@ public class FLLiteClient {
/**
* Set date path.
*
* @param dataPath, train or test dataset and label set.
* @return date size.
*/
public int setInput(String dataPath) {
public int setInput() {
int dataSize = 0;
if (Common.checkFLName(flParameter.getFlName())) {
dataSize = deprecatedSetInput(flParameter.getTrainDataset());
return dataSize;
}
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("==========set input==========="));
// train
dataSize = client.initDataSets(flParameter.getDataMap()).get(RunType.TRAINMODE);
if (dataSize <= 0) {
retCode = ResponseCode.RequestError;
return -1;
}
return dataSize;
}
private int deprecatedSetBatchSize(int batchSize) {
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AlTrainBert: " + batchSize));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
alTrainBert.setBatchSize(batchSize);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
TrainLenet trainLenet = TrainLenet.getInstance();
trainLenet.setBatchSize(batchSize);
} else {
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
return -1;
}
return 0;
}
private int deprecatedSetInput(String dataPath) {
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("==========set input==========="));
int dataSize = 0;
@ -653,61 +725,139 @@ public class FLLiteClient {
return dataSize;
}
/**
* Initialization session.
*
* @return the status code in client.
*/
public FLClientStatus initSession() {
int tag = 0;
private FLClientStatus deprecatedTrainLoop() {
retCode = ResponseCode.SUCCEED;
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
LOGGER.info(Common.addTag("[train] train in albert"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
failed("[train] unsolved error code in <alTrainBert.trainModel>", ResponseCode.RequestError);
}
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
"Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
LOGGER.info(Common.addTag("[train] train in lenet"));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
failed("[train] unsolved error code in <trainLenet.trainModel>", ResponseCode.RequestError);
}
} else {
failed("[train] the flName is not valid", ResponseCode.RequestError);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
"-1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
Common.freeSession();
return status;
}
/**
* Free session.
*/
protected void freeSession() {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
SessionUtil.free(alTrainBert.getTrainSession());
if (!flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("===========free inference session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
SessionUtil.free(alInferBert.getTrainSession());
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
SessionUtil.free(trainLenet.getTrainSession());
private Map<String, float[]> deprecatedGetFeatureMap() {
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
Common.freeSession();
retCode = ResponseCode.RequestError;
return new HashMap<>();
}
Map<String, float[]> featureMap = new HashMap<>();
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
Common.freeSession();
return featureMap;
}
private FLClientStatus deprecatedEvaluateLoop() {
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) {
float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlInferBert alInferBert = AlInferBert.getInstance();
int tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
if (tag == -1) {
failed("[evaluate] unsolved error code in <initSessionAndInputs>: the return is -1",
ResponseCode.RequestError);
return FLClientStatus.FAILED;
}
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile(), true);
if (dataSize <= 0) {
failed("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0",
ResponseCode.RequestError);
return status;
}
acc = alInferBert.evalModel();
SessionUtil.free(alInferBert.getTrainSession());
} else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), false);
if (tag == -1) {
failed("[evaluate] unsolved error code in <initSessionAndInputs>: the return is -1",
ResponseCode.RequestError);
return FLClientStatus.FAILED;
}
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile());
if (dataSize <= 0) {
failed("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0",
ResponseCode.RequestError);
return status;
}
acc = alTrainBert.evalModel();
SessionUtil.free(alTrainBert.getTrainSession());
}
if (Float.isNaN(acc)) {
failed("[evaluate] unsolved error code in <evalModel>: the return acc is NAN",
ResponseCode.RequestError);
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
" idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
if (flParameter.getTestDataset().split(",").length < 2) {
failed("[evaluate] the set testDataPath for lenet is not valid, should be the format of <data.bin," +
"label.bin>", ResponseCode.RequestError);
return status;
}
int tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
failed("[evaluate] unsolved error code in <initSessionAndInputs>: the return is -1",
ResponseCode.RequestError);
return FLClientStatus.FAILED;
}
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
flParameter.getTestDataset().split(",")[1]);
if (dataSize <= 0) {
failed("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0",
ResponseCode.RequestError);
return status;
}
float acc = trainLenet.evalModel();
SessionUtil.free(trainLenet.getTrainSession());
if (Float.isNaN(acc)) {
failed("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN",
ResponseCode.RequestError);
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
flParameter.getTestDataset().split(",")[1]));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
}
return status;
}
}

View File

@ -18,10 +18,19 @@ package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import com.mindspore.flclient.model.RunType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.logging.Logger;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.X509TrustManager;
/**
* Defines global parameters used during federated learning and these parameters are provided for users to set.
*
@ -41,21 +50,39 @@ public class FLParameter {
public static final int SLEEP_TIME = 1000;
private static volatile FLParameter flParameter;
private String deployEnv;
private String domainName;
private String certPath;
private SSLSocketFactory sslSocketFactory;
private X509TrustManager x509TrustManager;
private IFLJobResultCallback iflJobResultCallback = new FLJobResultCallback();
private String trainDataset;
private String vocabFile = "null";
private String idsFile = "null";
private String testDataset = "null";
private boolean useSSL = false;
private String flName;
private String trainModelPath;
private String inferModelPath;
private String sslProtocol = "TLSv1.2";
private String clientID;
private boolean useSSL = false;
private int timeOut;
private int sleepTime;
private boolean ifUseElb = false;
private int serverNum = 1;
private boolean ifPkiVerify = false;
private String equipCrlPath = "null";
private long validIterInterval = 3600000L;
private int threadNum = 1;
private BindMode cpuBindMode = BindMode.NOT_BINDING_CORE;
private List<String> trainWeightName = new ArrayList<>();
private List<String> inferWeightName = new ArrayList<>();
private Map<RunType, List<String>> dataMap = new HashMap<>();
private ServerMod serverMod;
private int batchSize;
private FLParameter() {
clientID = UUID.randomUUID().toString();
@ -79,10 +106,28 @@ public class FLParameter {
return localRef;
}
public String getDeployEnv() {
if (deployEnv == null || deployEnv.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <env> is null, please set it before using"));
throw new IllegalArgumentException();
}
return deployEnv;
}
public void setDeployEnv(String env) {
if (Common.checkEnv(env)) {
this.deployEnv = env;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <env> is not in envTrustList: x86, android, " +
"please check it before setting"));
throw new IllegalArgumentException();
}
}
public String getDomainName() {
if (domainName == null || domainName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is null or empty, please set it " +
"before use"));
"before using"));
throw new IllegalArgumentException();
}
return domainName;
@ -91,16 +136,33 @@ public class FLParameter {
public void setDomainName(String domainName) {
if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is not valid, it should be like " +
"as https://...... or http://......, please check it before set"));
"as https://...... or http://......, please check it before setting"));
throw new IllegalArgumentException();
}
this.domainName = domainName;
}
public String getClientID() {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
throw new IllegalArgumentException();
}
return clientID;
}
public void setClientID(String clientID) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check " +
"before setting"));
throw new IllegalArgumentException();
}
this.clientID = clientID;
}
public String getCertPath() {
if (certPath == null || certPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, please set it " +
"before use"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, the <certPath> " +
"must be set when conducting https communication, please set it by FLParameter.setCertPath()"));
throw new IllegalArgumentException();
}
return certPath;
@ -111,28 +173,70 @@ public class FLParameter {
if (Common.checkPath(realCertPath)) {
this.certPath = realCertPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it " +
"before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> does not exist, it must be a valid" +
" path when conducting https communication, please check it before setting"));
throw new IllegalArgumentException();
}
}
public SSLSocketFactory getSslSocketFactory() {
if (sslSocketFactory == null) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <sslSocketFactory> is null, the " +
"<sslSocketFactory> must be set when the deployEnv being \"android\", please set it by " +
"FLParameter.setSslSocketFactory()"));
throw new IllegalArgumentException();
}
return sslSocketFactory;
}
public void setSslSocketFactory(SSLSocketFactory sslSocketFactory) {
this.sslSocketFactory = sslSocketFactory;
}
public X509TrustManager getX509TrustManager() {
if (x509TrustManager == null) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <x509TrustManager> is null, the " +
"<x509TrustManager> must be set when the deployEnv being \"android\", please set it by " +
"FLParameter.setX509TrustManager()"));
throw new IllegalArgumentException();
}
return x509TrustManager;
}
public void setX509TrustManager(X509TrustManager x509TrustManager) {
this.x509TrustManager = x509TrustManager;
}
public IFLJobResultCallback getIflJobResultCallback() {
if (iflJobResultCallback == null) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <iflJobResultCallback> is null, please set it" +
" before using"));
throw new IllegalArgumentException();
}
return iflJobResultCallback;
}
public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback) {
this.iflJobResultCallback = iflJobResultCallback;
}
public String getTrainDataset() {
if (trainDataset == null || trainDataset.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null or empty, please set " +
"it before use"));
"it before using"));
throw new IllegalArgumentException();
}
return trainDataset;
}
public void setTrainDataset(String trainDataset) {
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
String realTrainDataset = Common.getRealPath(trainDataset);
if (Common.checkPath(realTrainDataset)) {
this.trainDataset = realTrainDataset;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it " +
"before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> does not exist, please check " +
"it before setting"));
throw new IllegalArgumentException();
}
}
@ -140,38 +244,41 @@ public class FLParameter {
public String getVocabFile() {
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before " +
"use"));
"using"));
throw new IllegalArgumentException();
}
return vocabFile;
}
public void setVocabFile(String vocabFile) {
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
String realVocabFile = Common.getRealPath(vocabFile);
if (Common.checkPath(realVocabFile)) {
this.vocabFile = realVocabFile;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it " +
"before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> does not exist, please check it " +
"before setting"));
throw new IllegalArgumentException();
}
}
public String getIdsFile() {
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before " +
"using"));
throw new IllegalArgumentException();
}
return idsFile;
}
public void setIdsFile(String idsFile) {
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
String realIdsFile = Common.getRealPath(idsFile);
if (Common.checkPath(realIdsFile)) {
this.idsFile = realIdsFile;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it " +
"before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> does not exist, please check it " +
"before setting"));
throw new IllegalArgumentException();
}
}
@ -181,12 +288,13 @@ public class FLParameter {
}
public void setTestDataset(String testDataset) {
LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
String realTestDataset = Common.getRealPath(testDataset);
if (Common.checkPath(realTestDataset)) {
this.testDataset = realTestDataset;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it " +
"before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> does not exist, please check it" +
" before setting"));
throw new IllegalArgumentException();
}
}
@ -194,27 +302,20 @@ public class FLParameter {
public String getFlName() {
if (flName == null || flName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null or empty, please set it " +
"before use"));
"before using"));
throw new IllegalArgumentException();
}
return flName;
}
public void setFlName(String flName) {
if (Common.checkFLName(flName)) {
this.flName = flName;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in FL_NAME_TRUST_LIST: " +
Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " +
"set"));
throw new IllegalArgumentException();
}
this.flName = flName;
}
public String getTrainModelPath() {
if (trainModelPath == null || trainModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null or empty, please set" +
" it before use"));
" it before using"));
throw new IllegalArgumentException();
}
return trainModelPath;
@ -225,8 +326,8 @@ public class FLParameter {
if (Common.checkPath(realTrainModelPath)) {
this.trainModelPath = realTrainModelPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check " +
"it before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> does not exist, please " +
"check it before setting"));
throw new IllegalArgumentException();
}
}
@ -234,7 +335,7 @@ public class FLParameter {
public String getInferModelPath() {
if (inferModelPath == null || inferModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null or empty, please set" +
" it before use"));
" it before using"));
throw new IllegalArgumentException();
}
return inferModelPath;
@ -245,8 +346,8 @@ public class FLParameter {
if (Common.checkPath(realInferModelPath)) {
this.inferModelPath = realInferModelPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check " +
"it before set"));
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> does not exist, please check" +
" it before setting"));
throw new IllegalArgumentException();
}
}
@ -256,9 +357,31 @@ public class FLParameter {
}
public void setUseSSL(boolean useSSL) {
LOGGER.warning(Common.addTag("Certificate authentication is required for https communicationthis parameter " +
"is true by default and no need to set it, " + Common.LOG_DEPRECATED));
this.useSSL = useSSL;
}
public String getSslProtocol() {
if (sslProtocol == null || sslProtocol.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <sslProtocol> is null or empty, please set it" +
" before using"));
throw new IllegalArgumentException();
}
return sslProtocol;
}
public void setSslProtocol(String sslProtocol) {
if (Common.checkSSLProtocol(sslProtocol)) {
this.sslProtocol = sslProtocol;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <sslProtocol> is not in sslProtocolTrustList " +
": " + Arrays.toString(Common.SSL_PROTOCOL_TRUST_LIST.toArray(new String[0])) + ", please check " +
"it before setting"));
throw new IllegalArgumentException();
}
}
public int getTimeOut() {
return timeOut;
}
@ -286,7 +409,7 @@ public class FLParameter {
public int getServerNum() {
if (serverNum <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> <= 0, it should be > 0, please " +
"set it before use"));
"set it before using"));
throw new IllegalArgumentException();
}
return serverNum;
@ -296,11 +419,159 @@ public class FLParameter {
this.serverNum = serverNum;
}
public String getClientID() {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
public boolean isPkiVerify() {
return ifPkiVerify;
}
public void setPkiVerify(boolean ifPkiVerify) {
this.ifPkiVerify = ifPkiVerify;
}
public String getEquipCrlPath() {
if (equipCrlPath == null || equipCrlPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <equipCrlPath> is null or empty, please set " +
"it before using"));
throw new IllegalArgumentException();
}
return clientID;
return equipCrlPath;
}
/**
* Obtains the Valid Iteration Interval set by a user.
*
* @return the Valid Iteration Interval.
*/
public long getValidInterval() {
if (validIterInterval <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <validIterInterval> is not valid, please set " +
"it as larger than 0."));
throw new IllegalArgumentException();
}
return validIterInterval;
}
public void setEquipCrlPath(String certPath) {
String realCertPath = Common.getRealPath(certPath);
if (Common.checkPath(realCertPath)) {
this.equipCrlPath = realCertPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <equipCrlPath> does not exist, please check " +
"it before setting"));
throw new IllegalArgumentException();
}
}
/**
* Set the Valid Iteration Interval.
*
* @param validInterval the Valid Iteration Interval.
*/
public void setValidInterval(long validInterval) {
if (validInterval > 0) {
this.validIterInterval = validInterval;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <validIterInterval> should be larger than 0, " +
"please set it again."));
throw new IllegalArgumentException();
}
}
public int getThreadNum() {
return threadNum;
}
public void setThreadNum(int threadNum) {
if (threadNum <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <threadNum> <= 0, please check it before " +
"setting"));
throw new IllegalArgumentException();
}
this.threadNum = threadNum;
}
public int getCpuBindMode() {
LOGGER.info(Common.addTag("[flParameter] the parameter of <cpuBindMode> is: " + cpuBindMode.toString() + " , " +
"the NOT_BINDING_CORE means that not binding core, BIND_LARGE_CORE means binding the large core, " +
"BIND_MIDDLE_CORE means binding the middle core"));
return cpuBindMode.ordinal();
}
public void setCpuBindMode(BindMode cpuBindMode) {
this.cpuBindMode = cpuBindMode;
}
public void setHybridWeightName(List<String> hybridWeightName, RunType runType) {
if (RunType.TRAINMODE.equals(runType)) {
this.trainWeightName = hybridWeightName;
} else if (RunType.INFERMODE.equals(runType)) {
this.inferWeightName = hybridWeightName;
} else {
LOGGER.severe(Common.addTag("[flParameter] the variable <runType> can only be set to <RunType.TRAINMODE> " +
"or <RunType.INFERMODE>, please check it"));
throw new IllegalArgumentException();
}
}
public List<String> getHybridWeightName(RunType runType) {
if (RunType.TRAINMODE.equals(runType)) {
if (trainWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainWeightName> is null, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return trainWeightName;
} else if (RunType.INFERMODE.equals(runType)) {
if (inferWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferWeightName> is null, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return inferWeightName;
} else {
LOGGER.severe(Common.addTag("[flParameter] the variable <runType> can only be set to <RunType.TRAINMODE> " +
"or <RunType.INFERMODE>, please check it"));
throw new IllegalArgumentException();
}
}
public Map<RunType, List<String>> getDataMap() {
if (dataMap.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <dataMaps> is null, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return dataMap;
}
public void setDataMap(Map<RunType, List<String>> dataMap) {
this.dataMap = dataMap;
}
public ServerMod getServerMod() {
if (serverMod == null) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverMod> is null, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return serverMod;
}
public void setServerMod(ServerMod serverMod) {
this.serverMod = serverMod;
}
public int getBatchSize() {
return batchSize;
}
public void setBatchSize(int batchSize) {
if (batchSize <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <batchSize> <= 0, please check it before " +
"setting"));
throw new IllegalArgumentException();
}
this.batchSize = batchSize;
}
}

View File

@ -16,16 +16,17 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestGetModel;
import mindspore.schema.ResponseCode;
@ -35,6 +36,9 @@ import java.util.ArrayList;
import java.util.Date;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Define the serialization method, handle the response message returned from server for getModel request.
*
@ -50,6 +54,7 @@ public class GetModel {
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int retCode = ResponseCode.RequestError;
private GetModel() {
}
@ -72,6 +77,10 @@ public class GetModel {
return localRef;
}
public int getRetCode() {
return retCode;
}
/**
* Get a flatBuffer builder of RequestGetModel.
*
@ -88,8 +97,8 @@ public class GetModel {
return builder.iteration(iteration).flName(name).time().build();
}
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) {
FLClientStatus status;
int fmCount = responseDataBuf.featureMapLength();
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
@ -113,6 +122,11 @@ public class GetModel {
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" +
" " + feature.dataLength()));
}
status = Common.initSession(flParameter.getInferModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
"model-----------------"));
@ -131,6 +145,7 @@ public class GetModel {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
Common.freeSession();
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
@ -145,6 +160,11 @@ public class GetModel {
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
feature.dataLength()));
}
status = Common.initSession(flParameter.getInferModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
@ -154,11 +174,13 @@ public class GetModel {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
Common.freeSession();
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus parseResponseLenet(ResponseGetModel responseDataBuf) {
private FLClientStatus deprecatedParseResponseLenet(ResponseGetModel responseDataBuf) {
FLClientStatus status;
int fmCount = responseDataBuf.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
@ -172,6 +194,11 @@ public class GetModel {
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
feature.dataLength()));
}
status = Common.initSession(flParameter.getInferModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
TrainLenet trainLenet = TrainLenet.getInstance();
@ -180,6 +207,116 @@ public class GetModel {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
Common.freeSession();
return FLClientStatus.SUCCESS;
}
private FLClientStatus deprecatedParseFeatures(ResponseGetModel responseDataBuf) {
FLClientStatus status = FLClientStatus.SUCCESS;
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
status = deprecatedParseResponseAlbert(responseDataBuf);
} else if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
status = deprecatedParseResponseLenet(responseDataBuf);
} else {
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
status = FLClientStatus.FAILED;
}
return status;
}
private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) {
FLClientStatus status;
Client client = ClientManager.getClient(flParameter.getFlName());
int fmCount = responseDataBuf.featureMapLength();
if (fmCount <= 0) {
LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
if (flParameter.getHybridWeightName(RunType.TRAINMODE).contains(featureName)) {
trainFeatureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] trainWeightFullname: " + feature.weightFullname() + ", " +
"trainWeightLength: " + feature.dataLength()));
}
if (flParameter.getHybridWeightName(RunType.INFERMODE).contains(featureName)) {
inferFeatureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] inferWeightFullname: " + feature.weightFullname() + ", " +
"inferWeightLength: " + feature.dataLength()));
}
}
Status tag;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
"model-----------------"));
status = Common.initSession(flParameter.getInferModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
tag = client.updateFeatures(flParameter.getInferModelPath(), inferFeatureMaps);
Common.freeSession();
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <Client.updateFeatures>"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
tag = client.updateFeatures(flParameter.getTrainModelPath(), inferFeatureMaps);
Common.freeSession();
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <Client.updateFeatures>"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", " +
"weightLength: " + feature.dataLength()));
}
Status tag;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps);
LOGGER.info(Common.addTag("[getModel] ===========free session============="));
Common.freeSession();
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <Client.updateFeatures>"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
}
return FLClientStatus.SUCCESS;
}
@ -190,25 +327,21 @@ public class GetModel {
* @return the status code corresponding to the response message.
*/
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
retCode = responseDataBuf.retcode();
LOGGER.info(Common.addTag("[getModel] ==========the response message of getModel is:================"));
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + retCode));
LOGGER.info(Common.addTag("[getModel] ==========reason: " + responseDataBuf.reason()));
LOGGER.info(Common.addTag("[getModel] ==========iteration: " + responseDataBuf.iteration()));
LOGGER.info(Common.addTag("[getModel] ==========time: " + responseDataBuf.timestamp()));
FLClientStatus status = FLClientStatus.SUCCESS;
int retCode = responseDataBuf.retcode();
switch (retCode) {
switch (responseDataBuf.retcode()) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[getModel] getModel response success"));
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
status = parseResponseAlbert(responseDataBuf);
} else if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
status = parseResponseLenet(responseDataBuf);
if (Common.checkFLName(flParameter.getFlName())) {
status = deprecatedParseFeatures(responseDataBuf);
} else {
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
LOGGER.info(Common.addTag("[getModel] into <parseResponseFeatures>"));
status = parseResponseFeatures(responseDataBuf);
}
return status;
case (ResponseCode.SucNotReady):

View File

@ -16,6 +16,8 @@
package com.mindspore.flclient;
import com.mindspore.lite.config.MSConfig;
import org.bouncycastle.math.ec.rfc7748.X25519;
import java.util.ArrayList;
@ -59,6 +61,16 @@ public class LocalFLParameter {
* The model name supported by federated learning tasks: "albert".
*/
public static final String ALBERT = "albert";
/**
* The deployment environment supported by federated learning tasks: "android".
*/
public static final String ANDROID = "android";
/**
* The deployment environment supported by federated learning tasks: "x86".
*/
public static final String X86 = "x86";
private static volatile LocalFLParameter localFLParameter;
private List<String> classifierWeightName = new ArrayList<>();
@ -67,6 +79,9 @@ public class LocalFLParameter {
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
private boolean stopJobFlag = false;
private MSConfig msConfig = new MSConfig();
private boolean useSSL = true;
private LocalFLParameter() {
// set classifierWeightName albertWeightName
@ -143,14 +158,14 @@ public class LocalFLParameter {
public void setEncryptLevel(String encryptLevel) {
if (encryptLevel == null || encryptLevel.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is null, please check it " +
"before set"));
"before setting"));
throw new IllegalArgumentException();
}
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set"));
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before setting"));
throw new IllegalArgumentException();
}
this.encryptLevel = encryptLevel;
@ -163,7 +178,7 @@ public class LocalFLParameter {
public void setEarlyStopMod(String earlyStopMod) {
if (earlyStopMod == null || earlyStopMod.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is null, please check it " +
"before set"));
"before setting"));
throw new IllegalArgumentException();
}
if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) &&
@ -171,7 +186,8 @@ public class LocalFLParameter {
(!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is " + earlyStopMod + " ," +
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set"));
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before " +
"setting"));
throw new IllegalArgumentException();
}
this.earlyStopMod = earlyStopMod;
@ -184,15 +200,44 @@ public class LocalFLParameter {
public void setServerMod(String serverMod) {
if (serverMod == null || serverMod.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is null, please check it " +
"before set"));
"before setting"));
throw new IllegalArgumentException();
}
if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) &&
(!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is " + serverMod + " , it " +
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set"));
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before setting"));
throw new IllegalArgumentException();
}
this.serverMod = serverMod;
}
public boolean isStopJobFlag() {
return stopJobFlag;
}
public void setStopJobFlag(boolean stopJobFlag) {
this.stopJobFlag = stopJobFlag;
}
public MSConfig getMsConfig() {
return msConfig;
}
public void setMsConfig(int DeviceType, int threadNum, int cpuBindMode, boolean enable_fp16) {
// arg 0: DeviceType:DT_CPU -> 0
// arg 1: ThreadNum -> 2
// arg 2: cpuBindMode:NO_BIND -> 0
// arg 3: enable_fp16 -> false
msConfig.init(DeviceType, threadNum, cpuBindMode, enable_fp16);
}
public boolean isUseSSL() {
return useSSL;
}
public void setUseSSL(boolean useSSL) {
this.useSSL = useSSL;
}
}

View File

@ -68,7 +68,7 @@ public class SSLSocketFactoryTools {
String domainName = flParameter.getDomainName();
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the <domainName> is null or not valid, it should" +
" be like as https://...... , please check!"));
" be like https://...... , please check!"));
throw new IllegalArgumentException();
}
if (domainName.split("//")[1].split(":").length < 2) {
@ -87,7 +87,7 @@ public class SSLSocketFactoryTools {
private void initSslSocketFactory() {
try {
sslContext = SSLContext.getInstance("TLS");
sslContext = SSLContext.getInstance(flParameter.getSslProtocol());
x509Certificate = readCert(flParameter.getCertPath());
myTrustManager = new MyTrustManager(x509Certificate);
sslContext.init(null, new TrustManager[]{
@ -137,9 +137,8 @@ public class SSLSocketFactoryTools {
"convert to X509Certificate"));
}
} catch (FileNotFoundException | CertificateException ex) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " +
"when creating " +
"CertificateFactory in readCert: " + ex.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch exception when creating CertificateFactory in " +
"readCert: invalid file or CertificateException"));
} finally {
try {
if (inputStream != null) {

View File

@ -16,20 +16,12 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
@ -52,7 +44,7 @@ public class SecureProtocol {
private double dpEps;
private double dpDelta;
private double dpNormClip;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private ArrayList<String> updateFeatureName = new ArrayList<String>();
private int retCode;
/**
@ -115,17 +107,17 @@ public class SecureProtocol {
*
* @return the feature names that needed to be encrypted.
*/
public ArrayList<String> getEncryptFeatureName() {
return encryptFeatureName;
public ArrayList<String> getUpdateFeatureName() {
return updateFeatureName;
}
/**
* Set the parameter encryptFeatureName.
* Set the parameter updateFeatureName.
*
* @param encryptFeatureName the feature names that needed to be encrypted.
* @param updateFeatureName the feature names that needed to be encrypted.
*/
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
this.encryptFeatureName = encryptFeatureName;
public void setUpdateFeatureName(ArrayList<String> updateFeatureName) {
this.updateFeatureName = updateFeatureName;
}
/**
@ -146,6 +138,10 @@ public class SecureProtocol {
LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============",
localFLParameter.getFlID()));
// round 0
if (localFLParameter.isStopJobFlag()) {
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
return status;
}
status = cipherClient.exchangeKeys();
retCode = cipherClient.getRetCode();
LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ",
@ -154,6 +150,10 @@ public class SecureProtocol {
return status;
}
// round 1
if (localFLParameter.isStopJobFlag()) {
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
return status;
}
status = cipherClient.shareSecrets();
retCode = cipherClient.getRetCode();
LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ",
@ -162,6 +162,10 @@ public class SecureProtocol {
return status;
}
// round2
if (localFLParameter.isStopJobFlag()) {
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
return status;
}
featureMask = cipherClient.doubleMaskingWeight();
if (featureMask == null || featureMask.length <= 0) {
LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" +
@ -180,30 +184,18 @@ public class SecureProtocol {
* @param trainDataSize trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
if (featureMask == null || featureMask.length == 0) {
LOGGER.severe("[Encrypt] feature mask is null, please check");
return new int[0];
}
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} else {
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
}
int featureSize = encryptFeatureName.size();
int featureSize = updateFeatureName.size();
int[] featuresMap = new int[featureSize];
int maskIndex = 0;
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
String key = updateFeatureName.get(i);
float[] data = trainedMap.get(key);
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
@ -349,21 +341,10 @@ public class SecureProtocol {
* @param trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map<String, float[]> trainedMap) {
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} else {
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
}
Map<String, float[]> mapBeforeTrain = modelMap;
int featureSize = encryptFeatureName.size();
int featureSize = updateFeatureName.size();
// calculate sigma
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
@ -371,8 +352,8 @@ public class SecureProtocol {
// calculate l2-norm of all layers' update array
double updateL2Norm = 0d;
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
String key = updateFeatureName.get(i);
float[] data = trainedMap.get(key);
float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
@ -391,12 +372,12 @@ public class SecureProtocol {
// clip and add noise
int[] featuresMap = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
if (!map.containsKey(key)) {
String key = updateFeatureName.get(i);
if (!trainedMap.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
return new int[0];
}
float[] data = map.get(key);
float[] data = trainedMap.get(key);
float[] data2 = new float[data.length];
if (!mapBeforeTrain.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");

View File

@ -1,29 +1,21 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.flclient.pki.PkiBean;
import com.mindspore.flclient.pki.PkiUtil;
import mindspore.schema.FLPlan;
import mindspore.schema.FeatureMap;
@ -31,9 +23,14 @@ import mindspore.schema.RequestFLJob;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseFLJob;
import java.io.IOException;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* StartFLJob
*
@ -52,11 +49,425 @@ public class StartFLJob {
private int featureSize;
private String nextRequestTime;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private ArrayList<String> updateFeatureName = new ArrayList<String>();
private int retCode = ResponseCode.RequestError;
private float lr = (float) 0.1;
private int batchSize;
private StartFLJob() {
}
/**
* getInstance of StartFLJob
*
* @return StartFLJob instance
*/
public static StartFLJob getInstance() {
StartFLJob localRef = startFLJob;
if (localRef == null) {
synchronized (StartFLJob.class) {
localRef = startFLJob;
if (localRef == null) {
startFLJob = localRef = new StartFLJob();
}
}
}
return localRef;
}
public String getNextRequestTime() {
return nextRequestTime;
}
public int getRetCode() {
return retCode;
}
/**
* get request start FLJob
*
* @param dataSize dataSize
* @param iteration iteration
* @param time time
* @param pkiBean pki bean
* @return byte[] data
*/
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time, PkiBean pkiBean) {
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
if (flParameter.isPkiVerify()) {
if (pkiBean == null) {
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <pkiBean> is null, please check!"));
throw new IllegalArgumentException();
}
return builder.flName(flParameter.getFlName())
.time(time)
.id(localFLParameter.getFlID())
.dataSize(dataSize)
.iteration(iteration)
.signData(pkiBean.getSignData())
.certificateChain(pkiBean.getCertificates())
.build();
}
return builder.flName(flParameter.getFlName())
.time(time)
.id(localFLParameter.getFlID())
.dataSize(dataSize)
.iteration(iteration)
.build();
}
public int getFeatureSize() {
return featureSize;
}
public ArrayList<String> getUpdateFeatureName() {
return updateFeatureName;
}
private FLClientStatus deprecatedParseResponseAlbert(ResponseFLJob flJob) {
FLClientStatus status;
int fmCount = flJob.featureMapLength();
updateFeatureName.clear();
if (fmCount <= 0) {
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
return FLClientStatus.FAILED;
}
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
featureSize += feature.dataLength();
updateFeatureName.add(feature.weightFullname());
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
}
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
"model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
Common.freeSession();
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
updateFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
Common.freeSession();
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) {
FLClientStatus status;
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
updateFeatureName.clear();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
updateFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
Common.freeSession();
return FLClientStatus.SUCCESS;
}
private FLClientStatus hybridFeatures(ResponseFLJob flJob) {
FLClientStatus status;
Client client = ClientManager.getClient(flParameter.getFlName());
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> trainFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
if (flParameter.getHybridWeightName(RunType.TRAINMODE).contains(featureName)) {
trainFeatureMaps.add(feature);
featureSize += feature.dataLength();
updateFeatureName.add(feature.weightFullname());
LOGGER.info(Common.addTag("[startFLJob] trainWeightFullname: " + feature.weightFullname() + ", " +
"trainWeightLength: " + feature.dataLength()));
}
if (flParameter.getHybridWeightName(RunType.INFERMODE).contains(featureName)) {
inferFeatureMaps.add(feature);
LOGGER.info(Common.addTag("[startFLJob] inferWeightFullname: " + feature.weightFullname() + ", " +
"inferWeightLength: " + feature.dataLength()));
}
}
Status tag;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
"model-----------------"));
status = Common.initSession(flParameter.getInferModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
tag = client.updateFeatures(flParameter.getInferModelPath(), inferFeatureMaps);
Common.freeSession();
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <Client.updateFeatures>"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
tag = client.setLearningRate(lr);
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
client.setBatchSize(batchSize);
tag = client.updateFeatures(flParameter.getTrainModelPath(), inferFeatureMaps);
Common.freeSession();
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <Client.updateFeatures>"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
return status;
}
private FLClientStatus normalFeatures(ResponseFLJob flJob) {
FLClientStatus status;
Client client = ClientManager.getClient(flParameter.getFlName());
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
updateFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
Status tag;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
tag = client.setLearningRate(lr);
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
client.setBatchSize(batchSize);
tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps);
LOGGER.info(Common.addTag("[startFLJob] ===========free session============="));
Common.freeSession();
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <Client.updateFeatures>"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
return status;
}
private FLClientStatus parseResponseFeatures(ResponseFLJob flJob) {
FLClientStatus status;
int fmCount = flJob.featureMapLength();
updateFeatureName.clear();
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseFeatures by " + localFLParameter.getServerMod()));
status = hybridFeatures(flJob);
if (status == FLClientStatus.FAILED) {
return status;
}
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseFeatures by " + localFLParameter.getServerMod()));
status = normalFeatures(flJob);
if (status == FLClientStatus.FAILED) {
return status;
}
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus deprecatedParseFeatures(ResponseFLJob flJob) {
FLClientStatus status = FLClientStatus.SUCCESS;
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
status = deprecatedParseResponseAlbert(flJob);
}
if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
status = deprecatedParseResponseLenet(flJob);
}
return status;
}
/**
* response res
*
* @param flJob ResponseFLJob
* @return FLClientStatus
*/
public FLClientStatus doResponse(ResponseFLJob flJob) {
if (flJob == null) {
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
FLPlan flPlanConfig = flJob.flPlanConfig();
if (flPlanConfig == null) {
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
if (flJob.featureMapLength() <= 0) {
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
retCode = ResponseCode.SystemError;
return FLClientStatus.FAILED;
}
retCode = flJob.retcode();
LOGGER.info(Common.addTag("[startFLJob] ==========the response message of startFLJob is:================"));
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + retCode));
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
nextRequestTime = flJob.nextReqTime();
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
FLClientStatus status;
int responseRetCode = flJob.retcode();
switch (responseRetCode) {
case (ResponseCode.SUCCEED):
localFLParameter.setServerMod(flPlanConfig.serverMode());
if (flPlanConfig.lr() != 0) {
lr = flPlanConfig.lr();
} else {
LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter <lr> from server: " + lr + " is not " +
"valid, " +
"will use the default value 0.1"));
}
batchSize = flPlanConfig.miniBatch();
if (Common.checkFLName(flParameter.getFlName())) {
status = deprecatedParseFeatures(flJob);
} else {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseFeatures>"));
status = parseResponseFeatures(flJob);
}
return status;
case (ResponseCode.OutOfTime):
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[startFLJob] the return <retCode> from server is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
class RequestStartFLJobBuilder {
private RequestFLJob requestFLJob;
private FlatBufferBuilder builder;
@ -65,6 +476,11 @@ public class StartFLJob {
private int dataSize = 0;
private int timestampOffset = 0;
private int idOffset = 0;
private int signDataOffset = 0;
private int keyAttestationOffset = 0;
private int equipCertOffset = 0;
private int equipCACertOffset = 0;
private int rootCertOffset = 0;
public RequestStartFLJobBuilder() {
builder = new FlatBufferBuilder();
@ -135,6 +551,50 @@ public class StartFLJob {
return this;
}
/**
* signData
*
* @param signData byte[]
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder signData(byte[] signData) {
if (signData == null || signData.length == 0) {
LOGGER.severe(
Common.addTag("[startFLJob] the parameter of <signData> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.signDataOffset = RequestFLJob.createSignDataVector(builder, signData);
return this;
}
/**
* set certificateChain
*
* @param certificates Certificate array
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder certificateChain(Certificate[] certificates) {
if (certificates == null || certificates.length < 4) {
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <certificates> is null or the length "
+ "is not valid (should be >= 4), please check!"));
throw new IllegalArgumentException();
}
try {
String keyAttestationPem = PkiUtil.getPemFormat(certificates[0]);
String equipCertPem = PkiUtil.getPemFormat(certificates[1]);
String equipCACertPem = PkiUtil.getPemFormat(certificates[2]);
String rootCertPem = PkiUtil.getPemFormat(certificates[3]);
this.keyAttestationOffset = this.builder.createString(keyAttestationPem);
this.equipCertOffset = this.builder.createString(equipCertPem);
this.equipCACertOffset = this.builder.createString(equipCACertPem);
this.rootCertOffset = this.builder.createString(rootCertPem);
} catch (IOException e) {
LOGGER.severe(Common.addTag("[StartFLJob] catch IOException in certificateChain: " + e.getMessage()));
}
return this;
}
/**
* build protobuffer
*
@ -152,210 +612,4 @@ public class StartFLJob {
return builder.sizedByteArray();
}
}
/**
* getInstance of StartFLJob
*
* @return StartFLJob instance
*/
public static StartFLJob getInstance() {
StartFLJob localRef = startFLJob;
if (localRef == null) {
synchronized (StartFLJob.class) {
localRef = startFLJob;
if (localRef == null) {
startFLJob = localRef = new StartFLJob();
}
}
}
return localRef;
}
public String getNextRequestTime() {
return nextRequestTime;
}
/**
* get request start FLJob
*
* @param dataSize dataSize
* @param iteration iteration
* @param time time
* @return byte[] data
*/
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
return builder.flName(flParameter.getFlName())
.time(time)
.id(localFLParameter.getFlID())
.dataSize(dataSize)
.iteration(iteration)
.build();
}
public int getFeatureSize() {
return featureSize;
}
public ArrayList<String> getEncryptFeatureName() {
return encryptFeatureName;
}
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength();
encryptFeatureName.clear();
if (fmCount <= 0) {
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
return FLClientStatus.FAILED;
}
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(feature.weightFullname());
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
}
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
"model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
}
return FLClientStatus.SUCCESS;
}
private FLClientStatus parseResponseLenet(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
encryptFeatureName.clear();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
/**
* response res
*
* @param flJob ResponseFLJob
* @return FLClientStatus
*/
public FLClientStatus doResponse(ResponseFLJob flJob) {
if (flJob == null) {
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
return FLClientStatus.FAILED;
}
FLPlan flPlanConfig = flJob.flPlanConfig();
if (flPlanConfig == null) {
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
nextRequestTime = flJob.nextReqTime();
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
FLClientStatus status = FLClientStatus.SUCCESS;
int retCode = flJob.retcode();
switch (retCode) {
case (ResponseCode.SUCCEED):
localFLParameter.setServerMod(flPlanConfig.serverMode());
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
status = parseResponseAlbert(flJob);
}
if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
status = parseResponseLenet(flJob);
}
return status;
case (ResponseCode.OutOfTime):
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[startFLJob] the return <retCode> from server is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
}

View File

@ -18,19 +18,26 @@ package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.ANDROID;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.flclient.pki.PkiUtil;
import com.mindspore.lite.config.CpuBindMode;
import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
@ -46,8 +53,38 @@ public class SyncFLJob {
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
private Map<String, float[]> oldFeatureMap;
private IFLJobResultCallback flJobResultCallback;
private FLClientStatus curStatus;
private void initFlIDForPkiVerify() {
if (flParameter.isPkiVerify()) {
LOGGER.info(Common.addTag("pkiVerify mode is open!"));
String equipCertHash = PkiUtil.genEquipCertHash(flParameter.getClientID());
if (equipCertHash == null || equipCertHash.isEmpty()) {
LOGGER.severe(Common.addTag("equipCertHash is empty, please check your mobile phone, only Huawei " +
"phones are supported now."));
throw new IllegalArgumentException();
}
LOGGER.info(Common.addTag("flID for pki verify is: " + equipCertHash));
localFLParameter.setFlID(equipCertHash);
} else {
LOGGER.info(Common.addTag("pkiVerify mode is not open!"));
localFLParameter.setFlID(flParameter.getClientID());
}
}
public SyncFLJob() {
if (!Common.checkFLName(flParameter.getFlName())) {
try {
LOGGER.info(Common.addTag("the flName: " + flParameter.getFlName()));
Class.forName(flParameter.getFlName());
} catch (ClassNotFoundException e) {
LOGGER.severe(Common.addTag("catch ClassNotFoundException error, the set flName does not exist, please " +
"check: " + e.getMessage()));
throw new IllegalArgumentException();
}
}
}
/**
* Starts a federated learning task on the device.
@ -55,209 +92,172 @@ public class SyncFLJob {
* @return the status code corresponding to the response message.
*/
public FLClientStatus flJobRun() {
Common.setSecureRandom(new SecureRandom());
localFLParameter.setFlID(flParameter.getClientID());
FLLiteClient client = new FLLiteClient();
FLClientStatus curStatus;
curStatus = client.initSession();
if (curStatus == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("init session failed"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
return curStatus;
flJobResultCallback = flParameter.getIflJobResultCallback();
if (!Common.checkFLName(flParameter.getFlName()) && ANDROID.equals(flParameter.getDeployEnv())) {
Common.setSecureRandom(Common.getFastSecureRandom());
} else {
Common.setSecureRandom(new SecureRandom());
}
initFlIDForPkiVerify();
localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false);
FLLiteClient flLiteClient = new FLLiteClient();
LOGGER.info(Common.addTag("recovery StopJobFlag to false in the start of fl job"));
localFLParameter.setStopJobFlag(false);
do {
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
int trainDataSize = client.setInput(flParameter.getTrainDataset());
if (trainDataSize <= 0) {
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
curStatus = FLClientStatus.FAILED;
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
client.getRetCode());
if (checkStopJobFlag()) {
break;
}
client.setTrainDataSize(trainDataSize);
LOGGER.info(Common.addTag("flName: " + flParameter.getFlName()));
int trainDataSize = flLiteClient.setInput();
if (trainDataSize <= 0) {
curStatus = FLClientStatus.FAILED;
failed("unsolved error code in <flLiteClient.setInput>: the return trainDataSize<=0, setInput",
flLiteClient);
break;
}
flLiteClient.setTrainDataSize(trainDataSize);
// startFLJob
curStatus = startFLJob(client);
curStatus = startFLJob(flLiteClient);
if (curStatus == FLClientStatus.RESTART) {
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
restart("[startFLJob]", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[startFLJob]", client.getIteration(), client.getRetCode(), curStatus);
failed("[startFLJob]", flLiteClient);
break;
}
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration()));
// get the feature map before train
getOldFeatureMap(client);
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration()));
// create mask
curStatus = client.getFeatureMask();
curStatus = flLiteClient.getFeatureMask();
if (curStatus == FLClientStatus.RESTART) {
restart("[Encrypt] creatMask", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
restart("[Encrypt] creatMask", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[Encrypt] createMask", client.getIteration(), client.getRetCode(), curStatus);
failed("[Encrypt] createMask", flLiteClient);
break;
}
// train
curStatus = client.localTrain();
curStatus = flLiteClient.localTrain();
if (curStatus == FLClientStatus.FAILED) {
failed("[train] train", client.getIteration(), client.getRetCode(), curStatus);
failed("[train] train", flLiteClient);
break;
}
LOGGER.info(Common.addTag("[train] train succeed"));
// updateModel
curStatus = updateModel(client);
curStatus = updateModel(flLiteClient);
if (curStatus == FLClientStatus.RESTART) {
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
restart("[updateModel]", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[updateModel] updateModel", client.getIteration(), client.getRetCode(), curStatus);
failed("[updateModel] updateModel", flLiteClient);
break;
}
// unmasking
curStatus = client.unMasking();
curStatus = flLiteClient.unMasking();
if (curStatus == FLClientStatus.RESTART) {
restart("[Encrypt] unmasking", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
restart("[Encrypt] unmasking", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[Encrypt] unmasking", client.getIteration(), client.getRetCode(), curStatus);
failed("[Encrypt] unmasking", flLiteClient);
break;
}
// getModel
curStatus = getModel(client);
curStatus = getModel(flLiteClient);
if (curStatus == FLClientStatus.RESTART) {
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
restart("[getModel]", flLiteClient.getNextRequestTime(), flLiteClient);
continue;
} else if (curStatus == FLClientStatus.FAILED) {
failed("[getModel] getModel", client.getIteration(), client.getRetCode(), curStatus);
failed("[getModel] getModel", flLiteClient);
break;
}
// get the feature map after averaging and update dp_norm_clip
updateDpNormClip(client);
flLiteClient.updateDpNormClip();
// evaluate model after getting model from server
if (flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" +
" model from server"));
if (!checkEvalPath()) {
LOGGER.info(Common.addTag("[evaluate] the data map set by user do not contain evaluation dataset, " +
"don't evaluate the model after getting model from server"));
} else {
curStatus = client.evaluateModel();
curStatus = flLiteClient.evaluateModel();
if (curStatus == FLClientStatus.FAILED) {
failed("[evaluate] evaluate", client.getIteration(), client.getRetCode(), curStatus);
failed("[evaluate] evaluate", flLiteClient);
break;
}
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
}
LOGGER.info(Common.addTag("========================================================the total response of "
+ client.getIteration() + ": " + curStatus +
+ flLiteClient.getIteration() + ": " + curStatus +
"======================================================================"));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
client.getRetCode());
} while (client.getIteration() < client.getIterations());
client.freeSession();
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(),
flLiteClient.getRetCode());
Common.freeSession();
} while (flLiteClient.getIteration() < flLiteClient.getIterations());
LOGGER.info(Common.addTag("flJobRun finish"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), flLiteClient.getIterations(),
flLiteClient.getRetCode());
return curStatus;
}
private FLClientStatus startFLJob(FLLiteClient client) {
FLClientStatus curStatus = client.startFLJob();
private FLClientStatus startFLJob(FLLiteClient flLiteClient) {
FLClientStatus curStatus = flLiteClient.startFLJob();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.startFLJob();
curStatus = flLiteClient.startFLJob();
}
return curStatus;
}
private FLClientStatus updateModel(FLLiteClient client) {
FLClientStatus curStatus = client.updateModel();
private FLClientStatus updateModel(FLLiteClient flLiteClient) {
FLClientStatus curStatus = flLiteClient.updateModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.updateModel();
curStatus = flLiteClient.updateModel();
}
return curStatus;
}
private FLClientStatus getModel(FLLiteClient client) {
FLClientStatus curStatus = client.getModel();
private FLClientStatus getModel(FLLiteClient flLiteClient) {
FLClientStatus curStatus = flLiteClient.getModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.getModel();
curStatus = flLiteClient.getModel();
}
return curStatus;
}
private void updateDpNormClip(FLLiteClient client) {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
int currentIter = client.getIteration();
Map<String, float[]> fedFeatureMap = getFeatureMap();
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
if (fedWeightUpdateNorm == -1) {
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
"-1, please check!"));
throw new IllegalArgumentException();
private boolean checkEvalPath() {
boolean tag = true;
if (Common.checkFLName(flParameter.getFlName())) {
if ("null".equals(flParameter.getTestDataset())) {
tag = false;
}
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm;
if (currentIter == 1) {
client.setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
} else {
if (newNormCLip < client.getDpNormClipAdapt()) {
client.setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
}
}
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt()));
return tag;
}
if (!flParameter.getDataMap().containsKey(RunType.EVALMODE)) {
LOGGER.info(Common.addTag("[evaluate] the data map set by user do not contain evaluation dataset, " +
"don't evaluate the model after getting model from server"));
tag = false;
return tag;
}
return tag;
}
private boolean checkStopJobFlag() {
if (localFLParameter.isStopJobFlag()) {
LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop"));
curStatus = FLClientStatus.FAILED;
return true;
} else {
return false;
}
}
private void getOldFeatureMap(FLLiteClient client) {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
Map<String, float[]> featureMap = getFeatureMap();
oldFeatureMap = client.getOldMapCopy(featureMap);
}
}
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
float updateL2Norm = 0f;
for (String key : originalData.keySet()) {
float[] data = originalData.get(key);
float[] dataAfterUpdate = newData.get(key);
for (int j = 0; j < data.length; j++) {
if (j >= dataAfterUpdate.length) {
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
"please check");
return -1;
}
float updateData = data[j] - dataAfterUpdate[j];
updateL2Norm += updateData * updateData;
}
}
updateL2Norm = (float) Math.sqrt(updateL2Norm);
return updateL2Norm;
}
private Map<String, float[]> getFeatureMap() {
Map<String, float[]> featureMap = new HashMap<>();
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
return featureMap;
}
/**
* Starts an inference task on the device.
@ -265,6 +265,292 @@ public class SyncFLJob {
* @return the status code corresponding to the response message.
*/
public int[] modelInference() {
if (Common.checkFLName(flParameter.getFlName())) {
return deprecatedModelInference();
}
Client client = ClientManager.getClient(flParameter.getFlName());
localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false);
localFLParameter.setStopJobFlag(false);
int[] labels = new int[0];
Map<RunType, Integer> dataSize = client.initDataSets(flParameter.getDataMap());
if (dataSize.isEmpty()) {
LOGGER.severe("[model inference] initDataSets failed, please check");
return new int[0];
}
Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[model inference] unsolved error code in <initSessionAndInputs>: the return " +
" status is: " + tag));
return new int[0];
}
client.setBatchSize(flParameter.getBatchSize());
LOGGER.info(Common.addTag("===========model inference============="));
labels = client.inferModel().stream().mapToInt(Integer::valueOf).toArray();
if (labels == null || labels.length == 0) {
LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please " +
"check");
}
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
client.free();
LOGGER.info(Common.addTag("[model inference] inference finish"));
return labels;
}
/**
* Obtains the latest model on the cloud.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getModel() {
if (!Common.checkFLName(flParameter.getFlName()) && ANDROID.equals(flParameter.getDeployEnv())) {
Common.setSecureRandom(Common.getFastSecureRandom());
} else {
Common.setSecureRandom(new SecureRandom());
}
if (Common.checkFLName(flParameter.getFlName())) {
return deprecatedGetModel();
}
localFLParameter.setServerMod(flParameter.getServerMod().toString());
localFLParameter.setMsConfig(0, 1, 0, false);
FLClientStatus status;
FLLiteClient flLiteClient = new FLLiteClient();
status = flLiteClient.getModel();
return status;
}
/**
* use to stop FL job.
*/
public void stopFLJob() {
LOGGER.info(Common.addTag("will stop the flJob"));
localFLParameter.setStopJobFlag(true);
Common.notifyObject();
}
private void waitSomeTime() {
if (flParameter.getSleepTime() != 0) {
Common.sleep(flParameter.getSleepTime());
} else {
Common.sleep(SLEEP_TIME);
}
}
private void waitNextReqTime(String nextReqTime) {
long waitTime = Common.getWaitTime(nextReqTime);
Common.sleep(waitTime);
}
private void restart(String tag, String nextReqTime, FLLiteClient flLiteClient) {
LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again"));
waitNextReqTime(nextReqTime);
Common.freeSession();
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(),
flLiteClient.getRetCode());
}
private void failed(String tag, FLLiteClient flLiteClient) {
LOGGER.info(Common.addTag(tag + " failed"));
LOGGER.info(Common.addTag("=========================================the total response of " +
flLiteClient.getIteration() + ": " + curStatus + "========================================="));
Common.freeSession();
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(),
flLiteClient.getRetCode());
}
private static Map<RunType, List<String>> createDatasetMap(String trainDataPath, String evalDataPath,
String inferDataPath, String pathRegex) {
Map<RunType, List<String>> dataMap = new HashMap<>();
if ((trainDataPath == null) || ("null".equals(trainDataPath)) || (trainDataPath.isEmpty())) {
LOGGER.info(Common.addTag("the trainDataPath is null or empty, please check if you are in the case of " +
"noly inference"));
} else {
dataMap.put(RunType.TRAINMODE, Arrays.asList(trainDataPath.split(pathRegex)));
LOGGER.info(Common.addTag("the trainDataPath: " + Arrays.toString(trainDataPath.split(pathRegex))));
}
if ((evalDataPath == null) || ("null".equals(evalDataPath)) || (evalDataPath.isEmpty())) {
LOGGER.info(Common.addTag("the evalDataPath is null or empty, please check if you are in the case of only" +
" trainning without evaluation"));
} else {
dataMap.put(RunType.EVALMODE, Arrays.asList(evalDataPath.split(pathRegex)));
LOGGER.info(Common.addTag("the evalDataPath: " + Arrays.toString(evalDataPath.split(pathRegex))));
}
if ((inferDataPath == null) || ("null".equals(inferDataPath)) || (inferDataPath.isEmpty())) {
LOGGER.info(Common.addTag("the inferDataPath is null or empty, please check if you are in the case of " +
"trainning without inference"));
} else {
dataMap.put(RunType.INFERMODE, Arrays.asList(inferDataPath.split(pathRegex)));
LOGGER.info(Common.addTag("the inferDataPath: " + Arrays.toString(inferDataPath.split(pathRegex))));
}
return dataMap;
}
private static void createWeightNameList(String trainWeightName, String inferWeightName, String nameRegex,
FLParameter flParameter) {
if ((trainWeightName == null) || ("null".equals(trainWeightName)) || (trainWeightName.isEmpty())) {
LOGGER.info(Common.addTag("the trainWeightName is null or empty, only need in " + ServerMod.HYBRID_TRAINING));
} else {
flParameter.setHybridWeightName(Arrays.asList(trainWeightName.split(nameRegex)), RunType.TRAINMODE);
LOGGER.info(Common.addTag("the trainWeightName: " + Arrays.toString(trainWeightName.split(nameRegex))));
}
if ((inferWeightName == null) || ("null".equals(inferWeightName)) || (inferWeightName.isEmpty())) {
LOGGER.info(Common.addTag("the inferWeightName is null or empty, only need in " + ServerMod.HYBRID_TRAINING));
} else {
flParameter.setHybridWeightName(Arrays.asList(inferWeightName.split(nameRegex)), RunType.INFERMODE);
LOGGER.info(Common.addTag("the trainWeightName: " + Arrays.toString(inferWeightName.split(nameRegex))));
}
}
private static void task(String[] args) {
String trainDataPath = args[0];
String evalDataPath = args[1];
String inferDataPath = args[2];
String pathRegex = args[3];
String flName = args[4];
String trainModelPath = args[5];
String inferModelPath = args[6];
String sslProtocol = args[7];
String deployEnv = args[8];
String domainName = args[9];
String certPath = args[10];
boolean useElb = Boolean.parseBoolean(args[11]);
int serverNum = Integer.parseInt(args[12]);
String task = args[13];
int threadNum = Integer.parseInt(args[14]);
String cpuBindMode = args[15];
String trainWeightName = args[16];
String inferWeightName = args[17];
String nameRegex = args[18];
String serverMod = args[19];
int batchSize = Integer.parseInt(args[20]);
FLParameter flParameter = FLParameter.getInstance();
// create dataset of map
Map<RunType, List<String>> dataMap = createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex);
// create weight name of list
createWeightNameList(trainWeightName, inferWeightName, nameRegex, flParameter);
flParameter.setFlName(flName);
SyncFLJob syncFLJob = new SyncFLJob();
Common.setIsHttps(domainName.split("//")[0].split(":")[0]);
switch (task) {
case "train":
LOGGER.info(Common.addTag("start syncFLJob.flJobRun()"));
if (Common.isHttps()) {
flParameter.setCertPath(certPath);
}
flParameter.setDataMap(dataMap);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
flParameter.setSslProtocol(sslProtocol);
flParameter.setDeployEnv(deployEnv);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
flParameter.setThreadNum(threadNum);
flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
flParameter.setBatchSize(batchSize);
syncFLJob.flJobRun();
break;
case "inference":
LOGGER.info(Common.addTag("start syncFLJob.modelInference()"));
flParameter.setDataMap(dataMap);
flParameter.setInferModelPath(inferModelPath);
flParameter.setThreadNum(threadNum);
flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
flParameter.setBatchSize(batchSize);
syncFLJob.modelInference();
break;
case "getModel":
LOGGER.info(Common.addTag("start syncFLJob.getModel()"));
if (Common.isHttps()) {
flParameter.setCertPath(certPath);
}
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
flParameter.setSslProtocol(sslProtocol);
flParameter.setDeployEnv(deployEnv);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
flParameter.setServerMod(ServerMod.valueOf(serverMod));
syncFLJob.getModel();
break;
default:
LOGGER.info(Common.addTag("do not do any thing!"));
}
}
private static void deprecatedTask(String[] args) {
String trainDataset = args[0];
String vocabFile = args[1];
String idsFile = args[2];
String testDataset = args[3];
String flName = args[4];
String trainModelPath = args[5];
String inferModelPath = args[6];
boolean useSSL = Boolean.parseBoolean(args[7]);
String domainName = args[8];
boolean useElb = Boolean.parseBoolean(args[9]);
int serverNum = Integer.parseInt(args[10]);
String certPath = args[11];
String task = args[12];
FLParameter flParameter = FLParameter.getInstance();
flParameter.setFlName(flName);
SyncFLJob syncFLJob = new SyncFLJob();
Common.setIsHttps(domainName.split("//")[0].split(":")[0]);
switch (task) {
case "train":
if (Common.isHttps()) {
flParameter.setCertPath(certPath);
}
flParameter.setTrainDataset(trainDataset);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
if (ALBERT.equals(flName)) {
flParameter.setVocabFile(vocabFile);
flParameter.setIdsFile(idsFile);
}
syncFLJob.flJobRun();
break;
case "inference":
flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath);
if (ALBERT.equals(flName)) {
flParameter.setVocabFile(vocabFile);
flParameter.setIdsFile(idsFile);
}
syncFLJob.modelInference();
break;
case "getModel":
if (Common.isHttps()) {
flParameter.setCertPath(certPath);
}
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
syncFLJob.getModel();
break;
default:
LOGGER.info(Common.addTag("do not do any thing!"));
}
}
private int[] deprecatedModelInference() {
int[] labels = new int[0];
if (flParameter.getFlName().equals(ALBERT)) {
AlInferBert alInferBert = AlInferBert.getInstance();
@ -290,166 +576,27 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("[model inference] inference finish"));
}
return labels;
}
/**
* Obtains the latest model on the cloud.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getModel() {
Common.setSecureRandom(Common.getFastSecureRandom());
int tag = 0;
private FLClientStatus deprecatedGetModel() {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
FLClientStatus status;
try {
if (flParameter.getFlName().equals(ALBERT)) {
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
flParameter.getTrainModelPath() + " Create Train Session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the " +
"return is -1"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " +
flParameter.getInferModelPath() + " Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
flParameter.getTrainModelPath() + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
return FLClientStatus.FAILED;
}
FLCommunication flCommunication = FLCommunication.getInstance();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.WAIT;
return status;
}
LOGGER.info(Common.addTag("[getModel] get model request success"));
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
status = getModelBuf.doResponse(responseDataBuf);
LOGGER.info(Common.addTag("[getModel] success!"));
} catch (Exception ex) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage()));
status = FLClientStatus.FAILED;
}
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
SessionUtil.free(alTrainBert.getTrainSession());
LOGGER.info(Common.addTag("===========free inference session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
SessionUtil.free(alInferBert.getTrainSession());
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
SessionUtil.free(trainLenet.getTrainSession());
}
FLLiteClient flLiteClient = new FLLiteClient();
status = flLiteClient.getModel();
return status;
}
private void waitSomeTime() {
if (flParameter.getSleepTime() != 0) {
Common.sleep(flParameter.getSleepTime());
} else {
Common.sleep(SLEEP_TIME);
}
}
private void waitNextReqTime(String nextReqTime) {
long waitTime = Common.getWaitTime(nextReqTime);
Common.sleep(waitTime);
}
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again"));
waitNextReqTime(nextReqTime);
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
}
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
LOGGER.info(Common.addTag(tag + " failed"));
LOGGER.info(Common.addTag("=========================================the total response of " +
iteration + ": " + curStatus + "========================================="));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
}
public static void main(String[] args) {
String trainDataset = args[0];
String vocabFile = args[1];
String idsFile = args[2];
String testDataset = args[3];
String flName = args[4];
String trainModelPath = args[5];
String inferModelPath = args[6];
boolean useSSL = Boolean.parseBoolean(args[7]);
String domainName = args[8];
boolean useElb = Boolean.parseBoolean(args[9]);
int serverNum = Integer.parseInt(args[10]);
String certPath = args[11];
String task = args[12];
FLParameter flParameter = FLParameter.getInstance();
SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) {
if (useSSL) {
flParameter.setCertPath(certPath);
}
flParameter.setTrainDataset(trainDataset);
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath);
flParameter.setUseSSL(useSSL);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
if (ALBERT.equals(flName)) {
flParameter.setVocabFile(vocabFile);
flParameter.setIdsFile(idsFile);
}
syncFLJob.flJobRun();
} else if (task.equals("inference")) {
flParameter.setFlName(flName);
flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath);
if (ALBERT.equals(flName)) {
flParameter.setVocabFile(vocabFile);
flParameter.setIdsFile(idsFile);
}
syncFLJob.modelInference();
} else if (task.equals("getModel")) {
if (useSSL) {
flParameter.setCertPath(certPath);
}
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
flParameter.setUseSSL(useSSL);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
syncFLJob.getModel();
} else {
LOGGER.info(Common.addTag("do not do any thing!"));
if (args[4] == null || args[4].isEmpty()) {
LOGGER.severe(Common.addTag("the parameter of <args[4]> is null, please check"));
throw new IllegalArgumentException();
}
if (Common.checkFLName(args[4])) {
deprecatedTask(args);
} else {
task(args);
}
}
}

View File

@ -16,14 +16,16 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.CommonUtils;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.model.TrainLenet;
import com.mindspore.lite.MSTensor;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestUpdateModel;
@ -33,9 +35,13 @@ import mindspore.schema.ResponseUpdateModel;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Define the serialization method, handle the response message returned from server for updateModel request.
*
@ -52,6 +58,7 @@ public class UpdateModel {
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private FLClientStatus status;
private int retCode = ResponseCode.RequestError;
private UpdateModel() {
}
@ -78,6 +85,10 @@ public class UpdateModel {
return status;
}
public int getRetCode() {
return retCode;
}
/**
* Get a flatBuffer builder of RequestUpdateModel.
*
@ -88,7 +99,16 @@ public class UpdateModel {
*/
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID())
boolean isPkiVerify = flParameter.isPkiVerify();
if (isPkiVerify) {
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration);
return builder.flName(flParameter.getFlName()).time(dateTime).id(localFLParameter.getFlID())
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).signData(signature).build();
}
return builder.flName(flParameter.getFlName()).time("null").id(localFLParameter.getFlID())
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
}
@ -99,8 +119,9 @@ public class UpdateModel {
* @return the status code corresponding to the response message.
*/
public FLClientStatus doResponse(ResponseUpdateModel response) {
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
retCode = response.retcode();
LOGGER.info(Common.addTag("[updateModel] ==========the response message of updateModel is================"));
LOGGER.info(Common.addTag("[updateModel] ==========retCode: " + retCode));
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
switch (response.retcode()) {
@ -120,6 +141,65 @@ public class UpdateModel {
}
}
private Map<String, float[]> getFeatureMap() {
Client client = ClientManager.getClient(flParameter.getFlName());
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
throw new IllegalArgumentException();
}
List<MSTensor> features = client.getFeatures();
Map<String, float[]> trainedMap = CommonUtils.convertTensorToFeatures(features);
LOGGER.info(Common.addTag("[updateModel] ===========free session============="));
Common.freeSession();
if (trainedMap.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return trainedMap is empty in <CommonUtils" +
".convertTensorToFeatures>"));
retCode = ResponseCode.RequestError;
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
return trainedMap;
}
private Map<String, float[]> deprecatedGetFeatureMap() {
status = Common.initSession(flParameter.getTrainModelPath());
if (status == FLClientStatus.FAILED) {
retCode = ResponseCode.RequestError;
throw new IllegalArgumentException();
}
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
} else {
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
Common.freeSession();
return map;
}
class RequestUpdateModelBuilder {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
@ -127,6 +207,7 @@ public class UpdateModel {
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
private int signDataOffset = 0;
private int iteration = 0;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
@ -153,12 +234,22 @@ public class UpdateModel {
/**
* Serialize the element timestamp in RequestUpdateModel.
*
* @param setTime current timestamp when the request is sent.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
private RequestUpdateModelBuilder time(String setTime) {
if (setTime == null || setTime.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the parameter of <setTime> is null or empty, please " +
"check!"));
throw new IllegalArgumentException();
}
if (setTime.equals("null")) {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
} else {
this.timestampOffset = builder.createString(setTime);
}
return this;
}
@ -189,10 +280,16 @@ public class UpdateModel {
}
private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
ArrayList<String> updateFeatureName = secureProtocol.getUpdateFeatureName();
Map<String, float[]> trainedMap = new HashMap<String, float[]>();
if (Common.checkFLName(flParameter.getFlName())) {
trainedMap = deprecatedGetFeatureMap();
} else {
trainedMap = getFeatureMap();
}
switch (encryptLevel) {
case PW_ENCRYPT:
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
"null, please check");
@ -202,10 +299,12 @@ public class UpdateModel {
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
case DP_ENCRYPT:
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
"null, please check");
retCode = ResponseCode.RequestError;
status = FLClientStatus.FAILED;
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
@ -213,36 +312,11 @@ public class UpdateModel {
return this;
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else {
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
throw new IllegalArgumentException();
}
int featureSize = encryptFeatureName.size();
int featureSize = updateFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
String key = updateFeatureName.get(i);
float[] data = trainedMap.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
"size: " + data.length));
for (int j = 0; j < data.length; j++) {
@ -258,6 +332,22 @@ public class UpdateModel {
}
}
/**
* Serialize the element signature in RequestUpdateModel.
*
* @param signData the signature Data.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder signData(byte[] signData) {
if (signData == null || signData.length == 0) {
LOGGER.severe(Common.addTag("[updateModel] the parameter of <signData> is null or empty, please " +
"check!"));
throw new IllegalArgumentException();
}
this.signDataOffset = RequestUpdateModel.createSignatureVector(builder, signData);
return this;
}
/**
* Create a flatBuffer builder of RequestUpdateModel.
*
@ -270,6 +360,7 @@ public class UpdateModel {
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
RequestUpdateModel.addSignature(builder, this.signDataOffset);
int root = RequestUpdateModel.endRequestUpdateModel(builder);
builder.finish(root);
return builder.sizedByteArray();

View File

@ -21,7 +21,6 @@ import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
import com.mindspore.flclient.Common;
import java.io.UnsupportedEncodingException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;

View File

@ -0,0 +1,411 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLParameter;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier;
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier;
import org.bouncycastle.util.encoders.Hex;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.InvalidKeyException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PublicKey;
import java.security.SignatureException;
import java.security.UnrecoverableEntryException;
import java.security.cert.CRLException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509CRL;
import java.security.cert.X509CRLEntry;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Date;
import java.util.Set;
import java.util.logging.Logger;
/**
* Certificate verification class
*
* @since 2021-8-27
*/
public class CertVerify {
private static final Logger LOGGER = Logger.getLogger(CertVerify.class.toString());
/**
* Verify the legitimacy of certificate chain
*
* @param clientID clientID of this client
* @param x509Certificates certificate chain
* @return verification result
*/
public static boolean verifyCertificateChain(String clientID, X509Certificate[] x509Certificates) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
return false;
}
if (x509Certificates == null || x509Certificates.length < 2) {
LOGGER.severe(Common.addTag("[CertVerify] the parameter x509Certificates is null or the length is not " +
"valid: < 2, please check!"));
return false;
}
if (verifyChain(clientID, x509Certificates) && verifyCommonName(clientID, x509Certificates)
&& verifyCrl(clientID, x509Certificates) && verifyValidDate(x509Certificates) &&
verifyKeyIdentifier(clientID, x509Certificates)) {
LOGGER.info(Common.addTag("[CertVerify] verifyCertificateChain success!"));
return true;
}
LOGGER.severe(Common.addTag("[CertVerify] verifyCertificateChain failed!"));
return false;
}
private static boolean verifyCommonName(String clientID, X509Certificate[] x509Certificate) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
return false;
}
if (x509Certificate == null || x509Certificate.length < 2) {
LOGGER.severe(Common.addTag("[CertVerify] x509Certificate chains is null or the length is not valid: < 2," +
" please check!"));
return false;
}
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
if (certificateChains == null || certificateChains.length < 4) {
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 4, " +
"please check!"));
return false;
}
X509Certificate localEquipCACert = certificateChains[2];
// get subjectDN of local root equipment CA certificate
String localEquipCAName = localEquipCACert.getSubjectDN().getName();
// get issueDN of client's equipment certificate
X509Certificate remoteEquipCert = x509Certificate[1];
String equipIssueName = remoteEquipCert.getIssuerDN().getName();
return localEquipCAName.equals(equipIssueName);
}
// check whether the former certificate owner is the publisher of next one.
private static boolean verifyChain(String clientID, X509Certificate[] x509Certificates) {
if (x509Certificates == null || x509Certificates.length < 2) {
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 2, " +
"please check!"));
return false;
}
// check remote equipment certificate
try {
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
if (certificateChains == null || certificateChains.length < 3) {
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 3, " +
"please check!"));
return false;
}
X509Certificate localEquipCA = certificateChains[2];
PublicKey publicKey = localEquipCA.getPublicKey();
x509Certificates[1].verify(publicKey);
} catch (NoSuchProviderException | CertificateException | NoSuchAlgorithmException |
InvalidKeyException | SignatureException e) {
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
return false;
}
// check remote service certificate
X509Certificate remoteEquipCert = x509Certificates[1];
X509Certificate remoteServiceCert = x509Certificates[0];
try {
remoteEquipCert.checkValidity();
remoteServiceCert.checkValidity();
} catch (java.security.cert.CertificateExpiredException |
java.security.cert.CertificateNotYetValidException e) {
e.printStackTrace();
return false;
}
try {
PublicKey publicKey = remoteEquipCert.getPublicKey();
remoteServiceCert.verify(publicKey);
} catch (CertificateException | NoSuchAlgorithmException | InvalidKeyException |
NoSuchProviderException | SignatureException e) {
LOGGER.severe(Common.addTag("verifyChain failed!"));
LOGGER.severe(Common.addTag("[verifyChain] catch Exception: " + e.getMessage()));
return false;
}
LOGGER.severe(Common.addTag("verifyChain success!"));
return true;
}
/**
* get certificate chain according to clientID
*
* @param clientID clientID of this client
* @return certificate chain
*/
public static X509Certificate[] getX509CertificateChain(String clientID) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
return null;
}
X509Certificate[] x509Certificates = null;
try {
Certificate[] certificates = null;
KeyStore keyStore = KeyStore.getInstance(CipherConsts.KEYSTORE_TYPE);
keyStore.load(null);
KeyStore.Entry entry = keyStore.getEntry(clientID, null);
if (entry == null || !(entry instanceof KeyStore.PrivateKeyEntry)) {
return null;
}
certificates = ((KeyStore.PrivateKeyEntry) entry).getCertificateChain();
if (certificates == null) {
return null;
}
x509Certificates = (X509Certificate[]) certificates;
} catch (IOException | NoSuchAlgorithmException | UnrecoverableEntryException | KeyStoreException |
CertificateException e) {
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
}
return x509Certificates;
}
/**
* transform pem format to X509 format
*
* @param pemCerts pem format certificate
* @return X509 format certificates
*/
public static X509Certificate[] transformPemArrayToX509Array(String[] pemCerts) {
if (pemCerts == null || pemCerts.length == 0) {
LOGGER.severe(Common.addTag("[CertVerify] pemCerts is null or empty, please check!"));
throw new IllegalArgumentException();
}
int nSize = pemCerts.length;
X509Certificate[] x509Certificates = new X509Certificate[nSize];
for (int i = 0; i < nSize; ++i) {
x509Certificates[i] = transformPemToX509(pemCerts[i]);
}
return x509Certificates;
}
private static X509Certificate transformPemToX509(String pemCert) {
X509Certificate x509Certificate = null;
CertificateFactory cf;
try {
if (pemCert != null && !pemCert.trim().isEmpty()) {
byte[] certificateData = Base64.getDecoder().decode(pemCert);
cf = CertificateFactory.getInstance("X509");
x509Certificate = (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(certificateData));
}
} catch (CertificateException e) {
LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage()));
return null;
}
return x509Certificate;
}
private static boolean verifyCrl(String clientID, X509Certificate[] x509Certificates) {
if (x509Certificates == null || x509Certificates.length < 2) {
LOGGER.severe(Common.addTag("[verifyCrl] the number of certificate in x509Certificates is less than 2, " +
"please check!"));
throw new IllegalArgumentException();
}
FLParameter flParameter = FLParameter.getInstance();
X509Certificate equipCert = x509Certificates[1];
if (equipCert == null) {
LOGGER.severe(Common.addTag("[verifyCrl] equipCert is null, please check it!"));
return false;
}
String equipCertSerialNum = equipCert.getSerialNumber().toString();
if (verifySingleCrl(clientID, equipCertSerialNum, flParameter.getEquipCrlPath())) {
LOGGER.info(Common.addTag("[verifyCrl] verify crl certificate success!"));
return true;
}
LOGGER.info(Common.addTag("[verifyCrl] verify crl certificate failed!"));
return false;
}
private static boolean verifySingleCrl(String clientID, String caSerialNumber, String crlPath) {
if (caSerialNumber == null || caSerialNumber.isEmpty()) {
LOGGER.severe(Common.addTag("[CertVerify] caSerialNumber is null or empty, please check!"));
throw new IllegalArgumentException();
}
// crlPath does not exist
if (crlPath.equals("null")) {
LOGGER.severe(Common.addTag("[CertVerify] crlPath is null, please set crlPath with setEquipCrlPath " +
"method!"));
return false;
}
boolean notInFlag = true;
try {
X509CRL crl = (X509CRL) readCrl(crlPath);
if (crl != null) {
// check CRL cert with local equipment CA publicKey
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
if (certificateChains == null || certificateChains.length < 3) {
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not" +
" valid: < 3, please check!"));
return false;
}
X509Certificate localEquipCA = certificateChains[2];
PublicKey publicKey = localEquipCA.getPublicKey();
crl.verify(publicKey);
// check whether remote equipmentCert in CRL
Set<?> set = crl.getRevokedCertificates();
if (set == null) {
LOGGER.info(Common.addTag("[verifySingleCrl] verifyCrl Revoked Cert list is null"));
return true;
}
for (Object obj : set) {
X509CRLEntry crlEntity = (X509CRLEntry) obj;
if (crlEntity.getSerialNumber().toString().equals(caSerialNumber)) {
LOGGER.info(Common.addTag("[verifySingleCrl] Find same SerialNumber during the crl!"));
notInFlag = false;
break;
}
}
}
} catch (java.security.cert.CRLException | java.security.NoSuchAlgorithmException |
java.security.InvalidKeyException | java.security.NoSuchProviderException |
java.security.SignatureException e) {
LOGGER.severe(Common.addTag("[verifySingleCrl] judgeCAInCRL error: " + e.getMessage()));
notInFlag = false;
}
return notInFlag;
}
private static Object readCrl(String assetName) {
if (assetName == null || assetName.isEmpty()) {
LOGGER.severe(Common.addTag("[readCrl] the parameter of <assetName> is null or empty, please check!"));
return null;
}
InputStream inputStream = null;
try {
inputStream = new FileInputStream(assetName);
} catch (IOException e) {
LOGGER.severe(Common.addTag("[readCrl] catch Exception of read inputStream in readCert: " +
e.getMessage()));
return null;
}
Object crlCert = null;
try {
CertificateFactory cf = CertificateFactory.getInstance("X.509");
crlCert = cf.generateCRL(inputStream);
} catch (CertificateException | CRLException e) {
LOGGER.severe(Common.addTag("[readCrl] catch Exception of creating CertificateFactory in readCert: "
+ e.getMessage()));
} finally {
try {
inputStream.close();
} catch (IOException e) {
LOGGER.severe(Common.addTag("[readCrl] catch Exception of close inputStream: "
+ e.getMessage()));
}
}
return crlCert;
}
private static boolean verifyValidDate(X509Certificate[] x509Certificates) {
if (x509Certificates == null) {
LOGGER.severe(Common.addTag("[CertVerify] x509Certificates is null, please check!"));
throw new IllegalArgumentException();
}
Date date = new Date();
try {
int nSize = x509Certificates.length;
for (int i = 0; i < nSize; ++i) {
x509Certificates[i].checkValidity(date);
}
} catch (java.security.cert.CertificateExpiredException |
java.security.cert.CertificateNotYetValidException e) {
LOGGER.severe(Common.addTag("[verifyValidDate] catch Exception: " + e.getMessage()));
return false;
}
return true;
}
private static boolean verifyKeyIdentifier(String clientID, X509Certificate[] x509Certificates) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!"));
return false;
}
if (x509Certificates == null || x509Certificates.length < 2) {
LOGGER.severe(Common.addTag("[CertVerify] x509Certificate chains is null or the length is not valid: < 2," +
" please check!"));
return false;
}
X509Certificate[] certificateChains = getX509CertificateChain(clientID);
if (certificateChains == null || certificateChains.length < 3) {
LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 3, " +
"please check!"));
return false;
}
X509Certificate localEquipCACert = certificateChains[2];
String subjectIdentifier = "null";
// get subjectKeyIdentifier of local root equipment CA certificate
try {
String subjectIdentifierOid = "2.5.29.14";
byte[] subjectExtendData = localEquipCACert.getExtensionValue(subjectIdentifierOid);
ASN1OctetString asn1OctetString = ASN1OctetString.getInstance(subjectExtendData);
byte[] tmpData = asn1OctetString.getOctets();
SubjectKeyIdentifier subjectKeyIdentifier = SubjectKeyIdentifier.getInstance(tmpData);
byte[] octKeyIdentifier = subjectKeyIdentifier.getKeyIdentifier();
subjectIdentifier = new String(Hex.encode(octKeyIdentifier));
} catch (ExceptionInInitializerError e) {
e.printStackTrace();
}
// get authorityKeyIdentifier of client's equipment certificate
X509Certificate remoteEquipCert = x509Certificates[1];
String authorityIdentifier = "null";
try {
if (remoteEquipCert == null) {
LOGGER.severe(Common.addTag("[CertVerify] remoteEquipCert is null, please check it!"));
return false;
}
String authorityIdentifierOid = "2.5.29.35";
byte[] authExtendData = remoteEquipCert.getExtensionValue(authorityIdentifierOid);
ASN1OctetString asn1OctetString = ASN1OctetString.getInstance(authExtendData);
byte[] tmpData = asn1OctetString.getOctets();
AuthorityKeyIdentifier authorityKeyIdentifier = AuthorityKeyIdentifier.getInstance(tmpData);
byte[] octKeyIdentifier = authorityKeyIdentifier.getKeyIdentifier();
authorityIdentifier = new String(Hex.encode(octKeyIdentifier));
} catch (ExceptionInInitializerError e) {
e.printStackTrace();
}
if (authorityIdentifier.equals("null") || subjectIdentifier.equals("null")) {
LOGGER.severe(Common.addTag("[CertVerify] authorityKeyIdentifier or subjectKeyIdentifier is null, check " +
"failed!"));
return false;
} else {
return authorityIdentifier.equals(subjectIdentifier);
}
}
}

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
/**
* Consts used for certification signature
*
* @since 2021-8-27
*/
public class CipherConsts {
/**
* provider name
*/
public static final String PROVIDER_NAME = "HwUniversalKeyStoreProvider";
/**
* keyStore type
*/
public static final String KEYSTORE_TYPE = "HwKeyStore";
/**
* sign algorithm
*/
public static final String SIGN_ALGORITHM = "SHA256withRSA/PSS";
}

View File

@ -21,6 +21,7 @@ import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.CipherClient;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
@ -91,12 +92,28 @@ public class ClientListReq {
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = builder.createString(dateTime);
GetClientList.startGetClientList(builder);
GetClientList.addFlId(builder, id);
GetClientList.addIteration(builder, iteration);
GetClientList.addTimestamp(builder, time);
int clientListRoot = GetClientList.endGetClientList(builder);
int clientListRoot;
byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration);
if (signature == null) {
LOGGER.severe(Common.addTag("[getClientList] get signature is null!"));
return FLClientStatus.FAILED;
}
if (signature.length > 0) {
int signed = GetClientList.createSignatureVector(builder, signature);
GetClientList.startGetClientList(builder);
GetClientList.addFlId(builder, id);
GetClientList.addIteration(builder, iteration);
GetClientList.addTimestamp(builder, time);
GetClientList.addSignature(builder, signed);
clientListRoot = GetClientList.endGetClientList(builder);
} else {
GetClientList.startGetClientList(builder);
GetClientList.addFlId(builder, id);
GetClientList.addIteration(builder, iteration);
GetClientList.addTimestamp(builder, time);
GetClientList.addSignature(builder, 0);
clientListRoot = GetClientList.endGetClientList(builder);
}
builder.finish(clientListRoot);
byte[] msg = builder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
@ -107,6 +124,7 @@ public class ClientListReq {
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);

View File

@ -20,6 +20,7 @@ import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.CipherClient;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
@ -68,7 +69,8 @@ public class ReconstructSecretReq {
*/
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList,
List<String> u3ClientList, int iteration) {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
FlatBufferBuilder builder = new FlatBufferBuilder();
int desFlId = builder.createString(localFLParameter.getFlID());
Date date = new Date();
@ -106,12 +108,28 @@ public class ReconstructSecretReq {
}
int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder,
decryptShareList);
SendReconstructSecret.startSendReconstructSecret(builder);
SendReconstructSecret.addFlId(builder, desFlId);
SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets);
SendReconstructSecret.addIteration(builder, iteration);
SendReconstructSecret.addTimestamp(builder, time);
int reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder);
int reconstructSecretRoot;
byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration);
if (signature.length > 0) {
int signed = SendReconstructSecret.createSignatureVector(builder, signature);
SendReconstructSecret.startSendReconstructSecret(builder);
SendReconstructSecret.addFlId(builder, desFlId);
SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets);
SendReconstructSecret.addIteration(builder, iteration);
SendReconstructSecret.addTimestamp(builder, time);
SendReconstructSecret.addSignature(builder, signed);
reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder);
} else {
SendReconstructSecret.startSendReconstructSecret(builder);
SendReconstructSecret.addFlId(builder, desFlId);
SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets);
SendReconstructSecret.addIteration(builder, iteration);
SendReconstructSecret.addTimestamp(builder, time);
SendReconstructSecret.addSignature(builder, 0);
reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder);
}
builder.finish(reconstructSecretRoot);
byte[] msg = builder.sizedByteArray();
try {
@ -121,6 +139,7 @@ public class ReconstructSecretReq {
"time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
retCode = ResponseCode.OutOfTime;
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);

View File

@ -0,0 +1,152 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.mindspore.flclient.Common;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.logging.Logger;
/**
* class used for sign data and verify data
*
* @since 2021-8-27
*/
public class SignAndVerify {
private static final Logger LOGGER = Logger.getLogger(SignAndVerify.class.toString());
/**
* sign data
*
* @param clientID ID of this client
* @param data data need to be signed
* @return signed data
*/
public static byte[] signData(String clientID, byte[] data) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter clientID is null or empty, please check!"));
return null;
}
if (data == null || data.length == 0) {
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter data is null or empty, please check!"));
return null;
}
byte[] signData = null;
try {
KeyStore ks = KeyStore.getInstance(CipherConsts.KEYSTORE_TYPE);
ks.load(null);
Key privateKey = ks.getKey(clientID, null);
if (privateKey == null) {
LOGGER.info("private key is null");
return null;
}
Signature signature = Signature.getInstance(CipherConsts.SIGN_ALGORITHM, CipherConsts.PROVIDER_NAME);
signature.initSign((PrivateKey) privateKey);
signature.update(data);
signData = signature.sign();
} catch (KeyStoreException | CertificateException | NoSuchAlgorithmException | IOException
| UnrecoverableKeyException | NoSuchProviderException | InvalidKeyException
| SignatureException e) {
LOGGER.severe(Common.addTag("[SignAndVerify] catch Exception: " + e.getMessage()));
}
return signData;
}
/**
* verify signature with certifications
*
* @param clientID ID of this client
* @param x509Certificates certificates
* @param data original data
* @param signed signed data
* @return verify result
*/
public static boolean verifySignatureByCert(String clientID, X509Certificate[] x509Certificates, byte[] data,
byte[] signed) {
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter clientID is null or empty, please check!"));
return false;
}
if (x509Certificates == null || x509Certificates.length < 1) {
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter x509Certificates is null or the length is not " +
"valid: < 1, please check!"));
return false;
}
if (data == null || data.length == 0) {
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter data is null or empty, please check!"));
return false;
}
if (signed == null || signed.length == 0) {
LOGGER.severe(Common.addTag("[SignAndVerify] the parameter signed is null or empty, please check!"));
return false;
}
if (!CertVerify.verifyCertificateChain(clientID, x509Certificates)) {
LOGGER.info(Common.addTag("Verify chain failed!"));
return false;
}
LOGGER.info(Common.addTag("Verify chain success!"));
boolean isValid;
try {
if (x509Certificates[0].getPublicKey() == null) {
LOGGER.severe(Common.addTag("[SignAndVerify] get public key failed!"));
return false;
}
PublicKey publicKey = x509Certificates[0].getPublicKey(); // get public key
Signature signature = Signature.getInstance(CipherConsts.SIGN_ALGORITHM);
signature.initVerify(publicKey); // set public key
signature.update(data); // set data
isValid = signature.verify(signed); // verify the consistence between signature and data
} catch (NoSuchAlgorithmException | SignatureException | InvalidKeyException e) {
LOGGER.severe(Common.addTag("[SignAndVerify] catch Exception: " + e.getMessage()));
return false;
}
return isValid;
}
/**
* get hash result of bytes
*
* @param bytes inputs
* @return hash value of bytes
*/
public static byte[] getSHA256(byte[] bytes) {
MessageDigest messageDigest;
byte[] hash = new byte[0];
try {
messageDigest = MessageDigest.getInstance("SHA-256");
hash = messageDigest.digest(bytes);
} catch (NoSuchAlgorithmException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch NoSuchAlgorithmException: " + e.getMessage()));
}
return hash;
}
}

View File

@ -40,7 +40,8 @@ public class ClientPublicKey {
*/
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before use"));
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before " +
"using"));
throw new IllegalArgumentException();
}
return flID;

View File

@ -38,7 +38,7 @@ public class EncryptShare {
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
"use"));
"using"));
throw new IllegalArgumentException();
}
return flID;

View File

@ -20,7 +20,6 @@ package com.mindspore.flclient.cipher.struct;
* class used define new array type
*
* @param <T> an array
*
* @since 2021-8-27
*/
public class NewArray<T> {

View File

@ -38,7 +38,7 @@ public class ShareSecret {
*/
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before use"));
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before using"));
throw new IllegalArgumentException();
}
return flID;

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*/
package com.mindspore.flclient.pki;
import java.security.cert.Certificate;
/**
* PkiBean entity
*
* @since 2021-08-25
*/
public class PkiBean {
private byte[] signData;
private Certificate[] certificates;
public PkiBean(byte[] signData, Certificate[] certificates) {
this.signData = signData;
this.certificates = certificates;
}
public byte[] getSignData() {
return signData;
}
public void setSignData(byte[] signData) {
this.signData = signData;
}
public Certificate[] getCertificates() {
return certificates;
}
public void setCertificates(Certificate[] certificates) {
this.certificates = certificates;
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*/
package com.mindspore.flclient.pki;
/**
* Pki consts
*
* @since 2021-08-25
*/
public class PkiConsts {
/**
* PROVIDER_NAME
*/
public static final String PROVIDER_NAME = "HwUniversalKeyStoreProvider";
/**
* KEYSTORE_TYPE
*/
public static final String KEYSTORE_TYPE = "HwKeyStore";
/**
* ALGORITHM
*/
public static final String ALGORITHM = "SHA256withRSA/PSS";
}

View File

@ -0,0 +1,247 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*/
package com.mindspore.flclient.pki;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.LocalFLParameter;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemWriter;
import java.io.IOException;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.Signature;
import java.security.SignatureException;
import java.security.UnrecoverableEntryException;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Locale;
import java.util.logging.Logger;
/**
* Pki Util
*
* @since 2021-08-25
*/
public class PkiUtil {
private static final Logger LOGGER = Logger.getLogger(PkiUtil.class.toString());
/**
* generate PkiBean
*
* @param clientID String
* @param time long
* @return PkiBean
*/
public static PkiBean genPkiBean(String clientID, long time) {
String sourceData = LocalFLParameter.getInstance().getFlID() + " " + time;
byte[] signDataBytes = PkiUtil.signData(clientID,
sourceData.getBytes(StandardCharsets.UTF_8));
Certificate[] certificates = PkiUtil.getCertificateChain(clientID);
return new PkiBean(signDataBytes, certificates);
}
/**
* get str of SHA256
*
* @param str String
* @return hash
*/
public static byte[] getSHA256Str(String str) {
MessageDigest messageDigest;
byte[] hash = new byte[0];
try {
messageDigest = MessageDigest.getInstance("SHA-256");
hash = messageDigest.digest(str.getBytes(StandardCharsets.UTF_8));
} catch (NoSuchAlgorithmException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch NoSuchAlgorithmException: " + e.getMessage()));
}
return hash;
}
/**
* get Pem format
*
* @param certificate Certificate
* @return String result
* @throws IOException e
*/
public static String getPemFormat(Certificate certificate) throws IOException {
StringWriter writer = new StringWriter();
PemWriter pemWriter = new PemWriter(writer);
try {
pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded()));
} catch (IOException | CertificateEncodingException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch IOException or CertificateEncodingException in getPermFormat: "
+ e.getMessage()));
} finally {
pemWriter.flush();
pemWriter.close();
}
return writer.toString();
}
private static byte[] signData(String clientID, byte[] data) {
byte[] signData = null;
try {
KeyStore ks = KeyStore.getInstance(PkiConsts.KEYSTORE_TYPE);
ks.load(null);
Key privateKey = ks.getKey(clientID, null);
if (privateKey == null) {
return new byte[0];
}
Signature signature = Signature.getInstance(PkiConsts.ALGORITHM, PkiConsts.PROVIDER_NAME);
if (privateKey instanceof PrivateKey) {
signature.initSign((PrivateKey) privateKey);
}
signature.update(data);
signData = signature.sign();
} catch (KeyStoreException | CertificateException | NoSuchAlgorithmException | IOException
| UnrecoverableKeyException | NoSuchProviderException | InvalidKeyException
| SignatureException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
}
return signData;
}
/**
* get certificate chain
*
* @param clientID String
* @return Certificate[]
*/
public static Certificate[] getCertificateChain(String clientID) {
Certificate[] certificates = null;
try {
KeyStore keyStore = KeyStore.getInstance(PkiConsts.KEYSTORE_TYPE);
keyStore.load(null);
KeyStore.Entry entry = keyStore.getEntry(clientID, null);
if (entry == null) {
return new Certificate[0];
}
if (!(entry instanceof KeyStore.PrivateKeyEntry)) {
return new Certificate[0];
}
certificates = ((KeyStore.PrivateKeyEntry) entry).getCertificateChain();
} catch (IOException | CertificateException | NoSuchAlgorithmException
| UnrecoverableEntryException | KeyStoreException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
}
return certificates;
}
/**
* to hex format
*
* @param data byte[]
* @return String
*/
public static String toHexFormat(byte[] data) {
if (data == null || data.length == 0) {
return "";
}
StringBuilder sb = new StringBuilder();
for (byte byteData : data) {
sb.append(String.format(Locale.ROOT, "%02x", byteData));
}
return sb.toString();
}
/**
* gen equip cert hash
*
* @param clientID String
* @return String
*/
public static String genEquipCertHash(String clientID) {
String equipCert;
byte[] equipCertBytesHash = null;
try {
Certificate[] certificates = getCertificateChain(clientID);
if (certificates == null || certificates.length < 2) {
return "";
}
equipCert = readPemFormat(certificates[1]);
equipCertBytesHash = getSHA256Str(equipCert);
} catch (IOException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
}
return toHexFormat(equipCertBytesHash);
}
/**
* generate hash from cer
*
* @param certificateGradeOne X509Certificate
* @return String
*/
public static String genHashFromCer(X509Certificate certificateGradeOne) {
String equipCert = null;
byte[] equipCertBytesHash = null;
try {
equipCert = readPemFormat(certificateGradeOne);
equipCertBytesHash = getSHA256Str(equipCert);
} catch (IOException e) {
LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage()));
}
if (equipCertBytesHash == null) {
return "";
}
StringBuilder sb = new StringBuilder();
for (byte byteData : equipCertBytesHash) {
sb.append(String.format(Locale.ROOT, "%02x", byteData));
}
return sb.toString();
}
/**
* read pem format
*
* @param certificate Certificate
* @return String result
* @throws IOException e
*/
public static String readPemFormat(Certificate certificate) throws IOException {
StringWriter writer = new StringWriter();
PemWriter pemWriter = new PemWriter(writer);
if (certificate == null) {
LOGGER.severe(Common.addTag("[PkiUtil] the input parameter certificate is null, please check"));
throw new IllegalArgumentException();
}
try {
pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded()));
} catch (IOException | CertificateEncodingException e) {
LOGGER.severe(
Common.addTag("[PkiUtil] catch IOException or CertificateEncodingException in getPermFormat: "
+ e.getMessage()));
} finally {
pemWriter.flush();
pemWriter.close();
}
return writer.toString();
}
}