From cf43399b20225554632bec431a786bdda8dbda55 Mon Sep 17 00:00:00 2001 From: zhoushan Date: Mon, 13 Dec 2021 11:25:07 +0800 Subject: [PATCH] add code of flclient for device decoupling --- .../java/com/mindspore/flclient/BindMode.java | 12 + .../com/mindspore/flclient/CipherClient.java | 609 ++++++++++++++- .../java/com/mindspore/flclient/Common.java | 197 ++++- .../mindspore/flclient/FLCommunication.java | 51 +- .../flclient/FLJobResultCallback.java | 2 + .../com/mindspore/flclient/FLLiteClient.java | 512 ++++++++----- .../com/mindspore/flclient/FLParameter.java | 347 ++++++++- .../java/com/mindspore/flclient/GetModel.java | 171 ++++- .../mindspore/flclient/LocalFLParameter.java | 57 +- .../flclient/SSLSocketFactoryTools.java | 9 +- .../mindspore/flclient/SecureProtocol.java | 79 +- .../com/mindspore/flclient/StartFLJob.java | 696 +++++++++++------ .../com/mindspore/flclient/SyncFLJob.java | 699 +++++++++++------- .../com/mindspore/flclient/UpdateModel.java | 173 ++++- .../mindspore/flclient/cipher/AESEncrypt.java | 1 - .../mindspore/flclient/cipher/CertVerify.java | 411 ++++++++++ .../flclient/cipher/CipherConsts.java | 38 + .../flclient/cipher/ClientListReq.java | 30 +- .../flclient/cipher/ReconstructSecretReq.java | 33 +- .../flclient/cipher/SignAndVerify.java | 152 ++++ .../cipher/struct/ClientPublicKey.java | 3 +- .../flclient/cipher/struct/EncryptShare.java | 2 +- .../flclient/cipher/struct/NewArray.java | 1 - .../flclient/cipher/struct/ShareSecret.java | 2 +- .../com/mindspore/flclient/pki/PkiBean.java | 39 + .../com/mindspore/flclient/pki/PkiConsts.java | 27 + .../com/mindspore/flclient/pki/PkiUtil.java | 247 +++++++ 27 files changed, 3683 insertions(+), 917 deletions(-) create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/BindMode.java create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CertVerify.java create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CipherConsts.java create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/SignAndVerify.java create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiBean.java create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiConsts.java create mode 100644 mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiUtil.java diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/BindMode.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/BindMode.java new file mode 100644 index 00000000000..336ed5ec17c --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/BindMode.java @@ -0,0 +1,12 @@ +package com.mindspore.flclient; + +/** + * The cpu bind mod. + * + * @since 2021-06-30 + */ +public enum BindMode { + NOT_BINDING_CORE, + BIND_LARGE_CORE, + BIND_MIDDLE_CORE +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java index 14f8caf2792..c7271702c03 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/CipherClient.java @@ -25,33 +25,44 @@ import com.google.flatbuffers.FlatBufferBuilder; import com.mindspore.flclient.cipher.AESEncrypt; import com.mindspore.flclient.cipher.BaseUtil; +import com.mindspore.flclient.cipher.CertVerify; import com.mindspore.flclient.cipher.ClientListReq; import com.mindspore.flclient.cipher.KEYAgreement; import com.mindspore.flclient.cipher.Masking; import com.mindspore.flclient.cipher.ReconstructSecretReq; import com.mindspore.flclient.cipher.ShareSecrets; +import com.mindspore.flclient.cipher.SignAndVerify; import com.mindspore.flclient.cipher.struct.ClientPublicKey; import com.mindspore.flclient.cipher.struct.DecryptShareSecrets; import com.mindspore.flclient.cipher.struct.EncryptShare; import com.mindspore.flclient.cipher.struct.NewArray; import com.mindspore.flclient.cipher.struct.ShareSecret; +import com.mindspore.flclient.pki.PkiUtil; import mindspore.schema.ClientShare; import mindspore.schema.GetExchangeKeys; import mindspore.schema.GetShareSecrets; +import mindspore.schema.RequestAllClientListSign; import mindspore.schema.RequestExchangeKeys; import mindspore.schema.RequestShareSecrets; +import mindspore.schema.ResponseClientListSign; import mindspore.schema.ResponseCode; import mindspore.schema.ResponseExchangeKeys; import mindspore.schema.ResponseShareSecrets; +import mindspore.schema.ReturnAllClientListSign; import mindspore.schema.ReturnExchangeKeys; import mindspore.schema.ReturnShareSecrets; +import mindspore.schema.SendClientListSign; import java.io.IOException; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.security.SecureRandom; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; import java.util.ArrayList; +import java.util.Base64; import java.util.Date; import java.util.HashMap; import java.util.List; @@ -93,6 +104,7 @@ public class CipherClient { private ClientListReq clientListReq = new ClientListReq(); private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq(); private int retCode; + private Map certificateList = new HashMap(); /** * Construct function of cipherClient @@ -423,12 +435,6 @@ public class CipherClient { " please check!")); return FLClientStatus.FAILED; } - FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); - byte[] cPK = cKey.get(0); - byte[] sPK = sKey.get(0); - int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK); - int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK); - byte[] indIv = new byte[I_VEC_LEN]; byte[] pwIv = new byte[I_VEC_LEN]; byte[] thisPwSalt = new byte[SALT_SIZE]; @@ -439,15 +445,24 @@ public class CipherClient { this.individualIv = indIv; this.pwIVec = pwIv; this.pwSalt = thisPwSalt; + FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); + int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv); int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv); int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt); + // for pkiVerify mode + int exchangeKeysRoot; + byte[] cPK = cKey.get(0); + byte[] sPK = sKey.get(0); + int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK); + int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK); int id = fbBuilder.createString(localFLParameter.getFlID()); Date date = new Date(); long timestamp = date.getTime(); String dateTime = String.valueOf(timestamp); int time = fbBuilder.createString(dateTime); + String clientID = flParameter.getClientID(); // start build RequestExchangeKeys.startRequestExchangeKeys(fbBuilder); @@ -459,7 +474,24 @@ public class CipherClient { RequestExchangeKeys.addIndIv(fbBuilder, indIvFbs); RequestExchangeKeys.addPwIv(fbBuilder, pwIvFbs); RequestExchangeKeys.addPwSalt(fbBuilder, pwSaltFbs); - int exchangeKeysRoot = RequestExchangeKeys.endRequestExchangeKeys(fbBuilder); + if (flParameter.isPkiVerify()) { + // waiting for certificates to take effect + int waitTakeEffectTime = 5000; + Common.sleep(waitTakeEffectTime); + int nSize = 2; // exchange equipment certificate and service equipment + String[] pemCertificateChains = transformX509ArrayToPemArray(CertVerify.getX509CertificateChain(clientID)); + int[] pemList = new int[nSize]; + for (int i = 0; i < nSize; i++) { + pemList[i] = fbBuilder.createString(pemCertificateChains[i]); + } + int certificatesInt = RequestExchangeKeys.createCertificateChainVector(fbBuilder, pemList); + byte[] signature = signPkAndTime(clientID, cPK, sPK, dateTime, iteration); + int signed = RequestExchangeKeys.createSignatureVector(fbBuilder, signature); + + RequestExchangeKeys.addSignature(fbBuilder, signed); + RequestExchangeKeys.addCertificateChain(fbBuilder, certificatesInt); + } + exchangeKeysRoot = RequestExchangeKeys.endRequestExchangeKeys(fbBuilder); fbBuilder.finish(exchangeKeysRoot); byte[] msg = fbBuilder.sizedByteArray(); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), @@ -472,6 +504,7 @@ public class CipherClient { "request again")); Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return FLClientStatus.RESTART; } ByteBuffer buffer = ByteBuffer.wrap(responseData); @@ -518,12 +551,30 @@ public class CipherClient { String dateTime = String.valueOf(timestamp); int time = fbBuilder.createString(dateTime); - // start build - GetExchangeKeys.startGetExchangeKeys(fbBuilder); - GetExchangeKeys.addFlId(fbBuilder, id); - GetExchangeKeys.addIteration(fbBuilder, iteration); - GetExchangeKeys.addTimestamp(fbBuilder, time); - int getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder); + int getExchangeKeysRoot; + byte[] signature = signTimeAndIter(dateTime, iteration); + if (signature == null) { + LOGGER.severe(Common.addTag("[getExchangeKeys] get signature is null!")); + return FLClientStatus.FAILED; + } + if (signature.length > 0) { + int signed = GetExchangeKeys.createSignatureVector(fbBuilder, signature); + // start build + GetExchangeKeys.startGetExchangeKeys(fbBuilder); + GetExchangeKeys.addFlId(fbBuilder, id); + GetExchangeKeys.addIteration(fbBuilder, iteration); + GetExchangeKeys.addTimestamp(fbBuilder, time); + GetExchangeKeys.addSignature(fbBuilder, signed); + getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder); + } else { + // start build + GetExchangeKeys.startGetExchangeKeys(fbBuilder); + GetExchangeKeys.addFlId(fbBuilder, id); + GetExchangeKeys.addIteration(fbBuilder, iteration); + GetExchangeKeys.addTimestamp(fbBuilder, time); + getExchangeKeysRoot = GetExchangeKeys.endGetExchangeKeys(fbBuilder); + } + fbBuilder.finish(getExchangeKeysRoot); byte[] msg = fbBuilder.sizedByteArray(); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), @@ -535,6 +586,7 @@ public class CipherClient { "request again")); Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return FLClientStatus.RESTART; } ByteBuffer buffer = ByteBuffer.wrap(responseData); @@ -558,22 +610,44 @@ public class CipherClient { clientPublicKeyList.clear(); u1ClientList.clear(); int length = bufData.remotePublickeysLength(); + for (int i = 0; i < length; i++) { + ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer(); + ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer(); + int sizeCpk = bufData.remotePublickeys(i).cPkLength(); + int sizeSpk = bufData.remotePublickeys(i).sPkLength(); + byte[] bufCpkList = byteBufferToList(bufCpk, sizeCpk); + byte[] bufSpkList = byteBufferToList(bufSpk, sizeSpk); + // copy bufCpkList and bufSpkList + byte[] cPkByte = bufCpkList.clone(); + byte[] sPkByte = bufSpkList.clone(); + + // check signature + boolean isPkiVerify = flParameter.isPkiVerify(); + if (isPkiVerify) { + FLClientStatus checkResult = checkSignature(bufData, i, cPkByte, sPkByte); + if (checkResult == FLClientStatus.FAILED) { + return FLClientStatus.FAILED; + } + } + ClientPublicKey publicKey = new ClientPublicKey(); String srcFlId = bufData.remotePublickeys(i).flId(); publicKey.setFlID(srcFlId); - ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer(); - int sizeCpk = bufData.remotePublickeys(i).cPkLength(); - ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer(); - int sizeSpk = bufData.remotePublickeys(i).sPkLength(); ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer(); int sizePwIv = bufData.remotePublickeys(i).pwIvLength(); ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer(); int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength(); - publicKey.setCPK(byteToArray(bufCpk, sizeCpk)); - publicKey.setSPK(byteToArray(bufSpk, sizeSpk)); publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv)); publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt)); + NewArray bufCpkArray = new NewArray<>(); + bufCpkArray.setSize(sizeCpk); + bufCpkArray.setArray(bufCpkList); + NewArray bufSpkArray = new NewArray<>(); + bufSpkArray.setSize(sizeSpk); + bufSpkArray.setArray(bufSpkList); + publicKey.setCPK(bufCpkArray); + publicKey.setSPK(bufSpkArray); clientPublicKeyList.put(srcFlId, publicKey); u1ClientList.add(srcFlId); } @@ -598,6 +672,86 @@ public class CipherClient { } } + private FLClientStatus checkSignature(ReturnExchangeKeys bufData, int dataIndex, byte[] cPkByte, byte[] sPkByte) { + ByteBuffer signature = bufData.remotePublickeys(dataIndex).signatureAsByteBuffer(); + byte[] sigByte = new byte[signature.remaining()]; + signature.get(sigByte); + int certifyNum = bufData.remotePublickeys(dataIndex).certificateChainLength(); + String[] pemCerts = new String[certifyNum]; + for (int certIndex = 0; certIndex < certifyNum; certIndex++) { + pemCerts[certIndex] = bufData.remotePublickeys(dataIndex).certificateChain(certIndex); + } + + X509Certificate[] x509Certificates = CertVerify.transformPemArrayToX509Array(pemCerts); + if (x509Certificates.length < 2) { + LOGGER.severe(Common.addTag("the length of x509Certificates is not valid, should be >= 2")); + return FLClientStatus.FAILED; + } + String certificateHash = PkiUtil.genHashFromCer(x509Certificates[1]); + LOGGER.info(Common.addTag("Get certificate hash success!")); + + // check srcId + String srcFlId = bufData.remotePublickeys(dataIndex).flId(); + if (certificateHash.equals(srcFlId)) { + LOGGER.info(Common.addTag("Check flID success and source flID is:" + srcFlId)); + } else { + LOGGER.severe(Common.addTag("Check flID failed!" + "source flID: " + srcFlId + "Hash ID from certificate:" + + " " + certificateHash.equals(srcFlId))); + return FLClientStatus.FAILED; + } + + certificateList.put(srcFlId, x509Certificates); + String timestamp = bufData.remotePublickeys(dataIndex).timestamp(); + String clientID = flParameter.getClientID(); + if (!verifySignature(clientID, x509Certificates, sigByte, cPkByte, sPkByte, timestamp, iteration)) { + LOGGER.info(Common.addTag("[PairWiseMask] FlID: " + srcFlId + + ", signature authentication failed")); + return FLClientStatus.FAILED; + } else { + LOGGER.info(Common.addTag("[PairWiseMask] Verify signature success!")); + } + + // check iteration and timestamp + int remoteIter = bufData.iteration(); + FLClientStatus iterTimeCheck = checkIterAndTimestamp(remoteIter, timestamp); + if (iterTimeCheck == FLClientStatus.FAILED) { + return FLClientStatus.FAILED; + } + + return FLClientStatus.SUCCESS; + } + + private FLClientStatus checkIterAndTimestamp(int remoteIter, String timestamp) { + if (remoteIter != iteration) { + LOGGER.severe(Common.addTag("[PairWiseMask] iteration check failed. Remote iteration of client: " + "is " + + remoteIter + ", which is not consistent with current iteration:" + iteration)); + return FLClientStatus.FAILED; + } + Date date = new Date(); + long currentTimeStamp = date.getTime(); + if (timestamp == null) { + LOGGER.severe(Common.addTag("[PairWiseMask] Received timeStamp is null,please check it!")); + return FLClientStatus.FAILED; + } + long remoteTimeStamp = Long.parseLong(timestamp); + long validIterInterval = flParameter.getValidInterval(); + if (currentTimeStamp - remoteTimeStamp > validIterInterval || currentTimeStamp < remoteTimeStamp) { + LOGGER.severe(Common.addTag("[PairWiseMask] timeStamp check failed! The difference between" + + " remote timestamp and current timestamp is beyond valid iteration interval!")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + private byte[] byteBufferToList(ByteBuffer buf, int size) { + byte[] array = new byte[size]; + for (int i = 0; i < size; i++) { + byte word = buf.get(); + array[i] = word; + } + return array; + } + private FLClientStatus requestShareSecrets() { FLClientStatus status = genIndividualSecret(); if (status == FLClientStatus.FAILED) { @@ -642,25 +796,44 @@ public class CipherClient { } int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add); - // start build - RequestShareSecrets.startRequestShareSecrets(fbBuilder); - RequestShareSecrets.addFlId(fbBuilder, id); - RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs); - RequestShareSecrets.addIteration(fbBuilder, iteration); - RequestShareSecrets.addTimestamp(fbBuilder, time); - int requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder); + int requestShareSecretsRoot; + byte[] signature = signTimeAndIter(dateTime, iteration); + if (signature == null) { + LOGGER.severe(Common.addTag("[PairWiseMask] get signature is null!")); + return FLClientStatus.FAILED; + } + if (signature.length > 0) { + int signed = RequestShareSecrets.createSignatureVector(fbBuilder, signature); + // start build + RequestShareSecrets.startRequestShareSecrets(fbBuilder); + RequestShareSecrets.addFlId(fbBuilder, id); + RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs); + RequestShareSecrets.addIteration(fbBuilder, iteration); + RequestShareSecrets.addTimestamp(fbBuilder, time); + RequestShareSecrets.addSignature(fbBuilder, signed); + requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder); + } else { + // start build + RequestShareSecrets.startRequestShareSecrets(fbBuilder); + RequestShareSecrets.addFlId(fbBuilder, id); + RequestShareSecrets.addEncryptedShares(fbBuilder, encryptedSharesFbs); + RequestShareSecrets.addIteration(fbBuilder, iteration); + RequestShareSecrets.addTimestamp(fbBuilder, time); + requestShareSecretsRoot = RequestShareSecrets.endRequestShareSecrets(fbBuilder); + } + fbBuilder.finish(requestShareSecretsRoot); byte[] msg = fbBuilder.sizedByteArray(); - String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), - flParameter.getDomainName()); try { + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), + flParameter.getDomainName()); byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg); if (!Common.isSeverReady(responseData)) { LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" + - " " + - "and request again")); + " and request again")); Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return FLClientStatus.RESTART; } ByteBuffer buffer = ByteBuffer.wrap(responseData); @@ -708,11 +881,27 @@ public class CipherClient { String dateTime = String.valueOf(timestamp); int time = fbBuilder.createString(dateTime); - GetShareSecrets.startGetShareSecrets(fbBuilder); - GetShareSecrets.addFlId(fbBuilder, id); - GetShareSecrets.addIteration(fbBuilder, iteration); - GetShareSecrets.addTimestamp(fbBuilder, time); - int getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder); + int getShareSecrets; + byte[] signature = signTimeAndIter(dateTime, iteration); + if (signature == null) { + LOGGER.severe(Common.addTag("[getShareSecrets] get signature is null!")); + return FLClientStatus.FAILED; + } + if (signature.length > 0) { + int signed = GetShareSecrets.createSignatureVector(fbBuilder, signature); + GetShareSecrets.startGetShareSecrets(fbBuilder); + GetShareSecrets.addFlId(fbBuilder, id); + GetShareSecrets.addIteration(fbBuilder, iteration); + GetShareSecrets.addTimestamp(fbBuilder, time); + GetShareSecrets.addSignature(fbBuilder, signed); + getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder); + } else { + GetShareSecrets.startGetShareSecrets(fbBuilder); + GetShareSecrets.addFlId(fbBuilder, id); + GetShareSecrets.addIteration(fbBuilder, iteration); + GetShareSecrets.addTimestamp(fbBuilder, time); + getShareSecrets = GetShareSecrets.endGetShareSecrets(fbBuilder); + } fbBuilder.finish(getShareSecrets); byte[] msg = fbBuilder.sizedByteArray(); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), @@ -724,6 +913,7 @@ public class CipherClient { "request again")); Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return FLClientStatus.RESTART; } ByteBuffer buffer = ByteBuffer.wrap(responseData); @@ -864,6 +1054,19 @@ public class CipherClient { } retCode = clientListReq.getRetCode(); + // clientListCheck + if (flParameter.isPkiVerify()) { + LOGGER.info(Common.addTag("[PairWiseMask] The mode is pkiVerify mode, start clientList check ...")); + curStatus = clientListCheck(); + while (curStatus == FLClientStatus.WAIT) { + Common.sleep(SLEEP_TIME); + curStatus = clientListCheck(); + } + if (curStatus != FLClientStatus.SUCCESS) { + return curStatus; + } + } + // SendReconstructSecret curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration); while (curStatus == FLClientStatus.WAIT) { @@ -876,4 +1079,340 @@ public class CipherClient { retCode = reconstructSecretReq.getRetCode(); return curStatus; } + + private byte[] signPkAndTime(String clientID, byte[] cPK, byte[] sPK, String time, int iterNum) { + // concatenate cPK, sPK and time + byte[] concatData = concatenateData(cPK, sPK, time, iterNum); + LOGGER.info("concatenate data success!"); + // signature + return SignAndVerify.signData(clientID, concatData); + } + + private static byte[] concatenateData(byte[] cPK, byte[] sPK, String time, int iterNum) { + // concatenate cPK, sPK and time + if (time == null) { + LOGGER.severe(Common.addTag("[concatenateData] input time is null, please check!")); + throw new IllegalArgumentException(); + } + byte[] byteTime = time.getBytes(StandardCharsets.UTF_8); + String iterString = String.valueOf(iterNum); + byte[] byteIter = iterString.getBytes(StandardCharsets.UTF_8); + int concatLength = cPK.length + sPK.length + byteTime.length + byteIter.length; + byte[] concatData = new byte[concatLength]; + + int offset = 0; + System.arraycopy(cPK, 0, concatData, offset, cPK.length); + + offset += cPK.length; + System.arraycopy(sPK, 0, concatData, offset, sPK.length); + + offset += sPK.length; + System.arraycopy(byteTime, 0, concatData, offset, byteTime.length); + + offset += byteTime.length; + System.arraycopy(byteIter, 0, concatData, offset, byteIter.length); + return concatData; + } + + private static byte[] concatenateIterAndTime(String time, int iterNum) { + // concatenate cPK, sPK and time + byte[] byteTime = time.getBytes(StandardCharsets.UTF_8); + String iterString = String.valueOf(iterNum); + byte[] byteIter = iterString.getBytes(StandardCharsets.UTF_8); + int concatLength = byteTime.length + byteIter.length; + byte[] concatData = new byte[concatLength]; + int offset = 0; + System.arraycopy(byteTime, 0, concatData, offset, byteTime.length); + offset += byteTime.length; + System.arraycopy(byteIter, 0, concatData, offset, byteIter.length); + return concatData; + } + + private static boolean verifySignature(String clientID, X509Certificate[] x509Certificates, byte[] signature, + byte[] cPK, byte[] sPK, String timestamp, int iteration) { + byte[] concatData = concatenateData(cPK, sPK, timestamp, iteration); + return SignAndVerify.verifySignatureByCert(clientID, x509Certificates, concatData, signature); + } + + private FLClientStatus clientListCheck() { + LOGGER.info(Common.addTag("[PairWiseMask] ==================== ClientListCheck ======================")); + FLClientStatus curStatus; + // send signed clientList + + curStatus = sendClientListSign(); + while (curStatus == FLClientStatus.WAIT) { + Common.sleep(SLEEP_TIME); + curStatus = sendClientListSign(); + } + if (curStatus != FLClientStatus.SUCCESS) { + return curStatus; + } + + // get signed clientList + + curStatus = getAllClientListSign(); + while (curStatus == FLClientStatus.WAIT) { + Common.sleep(SLEEP_TIME); + curStatus = getAllClientListSign(); + } + return curStatus; + } + + private FLClientStatus sendClientListSign() { + LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + + localFLParameter.getFlID() + "==============")); + genDHKeyPairs(); + List clientList = u3ClientList; + int listSize = u3ClientList.size(); + if (listSize == 0) { + LOGGER.severe("[Encrypt] u3List is empty, please check!"); + return FLClientStatus.FAILED; + } + + // send signature + byte[] clientListByte = transStringListToByte(clientList); + byte[] listHash = SignAndVerify.getSHA256(clientListByte); + String clientID = flParameter.getClientID(); + byte[] signature = SignAndVerify.signData(clientID, listHash); + if (signature == null) { + LOGGER.severe(Common.addTag("[sendClientListSign] the returned signature is null")); + return FLClientStatus.FAILED; + } + FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); + int signed = RequestExchangeKeys.createSignatureVector(fbBuilder, signature); + + int sendClientListRoot; + Date date = new Date(); + long timestamp = date.getTime(); + String dateTime = String.valueOf(timestamp); + byte[] reqSign = signTimeAndIter(dateTime, iteration); + String flID = localFLParameter.getFlID(); + int id = fbBuilder.createString(flID); + int time = fbBuilder.createString(dateTime); + if (signature.length > 0) { + int reqSigned = SendClientListSign.createSignatureVector(fbBuilder, reqSign); + sendClientListRoot = SendClientListSign.createSendClientListSign(fbBuilder, id, iteration, time, signed, + reqSigned); + } else { + sendClientListRoot = SendClientListSign.createSendClientListSign(fbBuilder, id, iteration, time, signed, 0); + } + + fbBuilder.finish(sendClientListRoot); + byte[] msg = fbBuilder.sizedByteArray(); + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), + flParameter.getDomainName()); + try { + byte[] responseData = flCommunication.syncRequest(url + "/pushListSign", msg); + + if (!Common.isSeverReady(responseData)) { + LOGGER.info(Common.addTag("[sendClientListSign] the server is not ready now, need wait some time and " + + "request again")); + Common.sleep(SLEEP_TIME); + nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; + return FLClientStatus.RESTART; + } + ByteBuffer buffer = ByteBuffer.wrap(responseData); + ResponseClientListSign responseClientListSign = + ResponseClientListSign.getRootAsResponseClientListSign(buffer); + return judgeRequestClientList(responseClientListSign); + } catch (IOException e) { + e.printStackTrace(); + return FLClientStatus.FAILED; + } + } + + private FLClientStatus judgeRequestClientList(ResponseClientListSign bufData) { + retCode = bufData.retcode(); + LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestClientListSign**************")); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); + LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); + LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); + LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); + switch (retCode) { + case (ResponseCode.SUCCEED): + LOGGER.info(Common.addTag("[PairWiseMask] RequestClientListSign success")); + return FLClientStatus.SUCCESS; + case (ResponseCode.OutOfTime): + LOGGER.info(Common.addTag("[PairWiseMask] RequestClientListSign out of time: need wait and request " + + "startFLJob again")); + setNextRequestTime(bufData.nextReqTime()); + return FLClientStatus.RESTART; + case (ResponseCode.RequestError): + case (ResponseCode.SystemError): + LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestClientListSign")); + return FLClientStatus.FAILED; + default: + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in RequestClientListSign" + + " is invalid: " + retCode)); + return FLClientStatus.FAILED; + } + } + + private FLClientStatus getAllClientListSign() { + FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); + int id = fbBuilder.createString(localFLParameter.getFlID()); + Date date = new Date(); + long timestamp = date.getTime(); + String dateTime = String.valueOf(timestamp); + int time = fbBuilder.createString(dateTime); + + int requestAllClientListSign; + byte[] signature = signTimeAndIter(dateTime, iteration); + if (signature.length > 0) { + int signed = RequestAllClientListSign.createSignatureVector(fbBuilder, signature); + requestAllClientListSign = RequestAllClientListSign.createRequestAllClientListSign(fbBuilder, id, + iteration, time, signed); + } else { + requestAllClientListSign = RequestAllClientListSign.createRequestAllClientListSign(fbBuilder, id, + iteration, time, 0); + } + + fbBuilder.finish(requestAllClientListSign); + byte[] msg = fbBuilder.sizedByteArray(); + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), + flParameter.getDomainName()); + try { + byte[] responseData = flCommunication.syncRequest(url + "/getListSign", msg); + + if (!Common.isSeverReady(responseData)) { + LOGGER.info(Common.addTag("[getAllClientListSign] the server is not ready now, need wait some time " + + "and request again")); + Common.sleep(SLEEP_TIME); + nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; + return FLClientStatus.RESTART; + } + ByteBuffer buffer = ByteBuffer.wrap(responseData); + ReturnAllClientListSign returnAllClientList = + ReturnAllClientListSign.getRootAsReturnAllClientListSign(buffer); + return judgeAllClientList(returnAllClientList); + } catch (IOException e) { + e.printStackTrace(); + return FLClientStatus.FAILED; + } + } + + private FLClientStatus judgeAllClientList(ReturnAllClientListSign bufData) { + retCode = bufData.retcode(); + LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetAllClientsList**************")); + LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); + LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); + LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); + LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); + switch (retCode) { + case (ResponseCode.SUCCEED): + LOGGER.info(Common.addTag("[PairWiseMask] GetAllClientList success")); + int length = bufData.clientListSignLength(); + String clientID = flParameter.getClientID(); + String localFlID = localFLParameter.getFlID(); + byte[] localClientList = transStringListToByte(u3ClientList); + byte[] localListHash = SignAndVerify.getSHA256(localClientList); + for (int i = 0; i < length; i++) { + // verify signature + ByteBuffer signature = bufData.clientListSign(i).signatureAsByteBuffer(); + byte[] sigByte = new byte[signature.remaining()]; + signature.get(sigByte); + if (bufData.clientListSign(i).flId() == null) { + LOGGER.severe(Common.addTag("[PairWiseMask] get flID failed!")); + return FLClientStatus.FAILED; + } + String srcFlId = bufData.clientListSign(i).flId(); + X509Certificate[] remoteCertificates = certificateList.get(srcFlId); + if (localFlID.equals(srcFlId)) { + continue; + } // Do not verify itself + if (!SignAndVerify.verifySignatureByCert(clientID, remoteCertificates, localListHash, sigByte)) { + LOGGER.info(Common.addTag("[PairWiseMask] FlID: " + srcFlId + + ", signature authentication failed")); + return FLClientStatus.FAILED; + } + } + return FLClientStatus.SUCCESS; + case (ResponseCode.SucNotReady): + LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " + + "GetAllClientsList again!")); + return FLClientStatus.WAIT; + case (ResponseCode.OutOfTime): + LOGGER.info(Common.addTag("[PairWiseMask] GetAllClientsList out of time: need wait and request " + + "startFLJob again")); + setNextRequestTime(bufData.nextReqTime()); + return FLClientStatus.RESTART; + case (ResponseCode.RequestError): + case (ResponseCode.SystemError): + LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetAllClientsList")); + return FLClientStatus.FAILED; + default: + LOGGER.severe(Common.addTag("[PairWiseMask] the return from server in ReturnAllClientList " + + "is invalid: " + retCode)); + return FLClientStatus.FAILED; + } + } + + private byte[] transStringListToByte(List stringList) { + int byteNum = 0; + for (String value : stringList) { + byte[] stringByte = value.getBytes(StandardCharsets.UTF_8); + byteNum += stringByte.length; + } + byte[] concatData = new byte[byteNum]; + int offset = 0; + for (String str : stringList) { + byte[] stringByte = str.getBytes(StandardCharsets.UTF_8); + System.arraycopy(stringByte, 0, concatData, offset, stringByte.length); + offset += stringByte.length; + } + return concatData; + } + + /** + * Add signature on timestamp and iteration + * + * @param dateTime the timestamp of data + * @param iteration iteration number + * @return signed time and iteration + */ + public static byte[] signTimeAndIter(String dateTime, int iteration) { + // signature + FLParameter flParameter = FLParameter.getInstance(); + String clientID = flParameter.getClientID(); + boolean isPkiVerify = flParameter.isPkiVerify(); + byte[] signature = new byte[0]; + if (isPkiVerify) { + LOGGER.info(Common.addTag("ClientID is:" + clientID)); + byte[] concatData = concatenateIterAndTime(dateTime, iteration); + signature = SignAndVerify.signData(clientID, concatData); + } + return signature; + } + + private static String transformX509ToPem(X509Certificate x509Certificate) { + if (x509Certificate == null) { + LOGGER.severe(Common.addTag("[CertVerify] x509Certificate is null, please check!")); + return null; + } + String pemCert; + try { + byte[] derCert = x509Certificate.getEncoded(); + pemCert = new String(Base64.getEncoder().encode(derCert)); + } catch (CertificateEncodingException e) { + LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage())); + return null; + } + return pemCert; + } + + private static String[] transformX509ArrayToPemArray(X509Certificate[] x509Certificates) { + if (x509Certificates == null || x509Certificates.length == 0) { + LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or empty, please check!")); + throw new IllegalArgumentException(); + } + int nSize = x509Certificates.length; + String[] pemCerts = new String[nSize]; + for (int i = 0; i < nSize; ++i) { + String pemCert = transformX509ToPem(x509Certificates[i]); + pemCerts[i] = pemCert; + } + return pemCerts; + } } \ No newline at end of file diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java index 8cf8946e7ee..d2fd875055d 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java @@ -16,6 +16,14 @@ package com.mindspore.flclient; +import com.mindspore.flclient.model.AlInferBert; +import com.mindspore.flclient.model.AlTrainBert; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.Status; +import com.mindspore.flclient.model.TrainLenet; +import mindspore.schema.ResponseCode; import org.bouncycastle.crypto.BlockCipher; import org.bouncycastle.crypto.engines.AESEngine; import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder; @@ -33,6 +41,9 @@ import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + /** * Define basic global methods used in federated learning task. * @@ -44,6 +55,9 @@ public class Common { */ public static final String LOG_TITLE = " "; + public static final String LOG_DEPRECATED = "This method will be deprecated in the next version, it is " + + "recommended to " + + "use the latest method according to the use cases in the official website tutorial"; /** * The list of trust flName. */ @@ -63,8 +77,16 @@ public class Common { * The tag when server is not ready. */ public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished."; + + /** + * use to stop job. + */ + private static final Object STOP_OBJECT = new Object(); private static final Logger LOGGER = Logger.getLogger(Common.class.toString()); + private static List envTrustList = new ArrayList<>(Arrays.asList("x86", "android")); private static SecureRandom secureRandom; + private static int iteration; + private static boolean isHttps; /** * Generate the URL for device-sever interaction @@ -107,6 +129,7 @@ public class Common { throw new IllegalArgumentException(); } String tag = domainName.split("//")[0] + "//"; + setIsHttps(domainName.split("//")[0].split(":")[0]); Random rand = new Random(); int randomNum = rand.nextInt(100000) % serverNum + port; url = tag + ip + ":" + String.valueOf(randomNum); @@ -167,6 +190,17 @@ public class Common { return (FL_NAME_TRUST_LIST.contains(flName)); } + /** + * Check if the deploy environment set by user is in the trust list. + * + * @param env the deploy environment for federated learning task set by user. + * @return boolean value, true indicates the deploy environment set by user is valid, false indicates the deploy + * environment set by user is not valid. + */ + public static boolean checkEnv(String env) { + return (envTrustList.contains(env)); + } + /** * Check whether the sslProtocol set by user is in the trust list. * @@ -184,11 +218,23 @@ public class Common { * @param millis the waiting time (ms). */ public static void sleep(long millis) { - try { - Thread.sleep(millis); // 1000 milliseconds is one second. - } catch (InterruptedException ex) { - LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage())); - Thread.currentThread().interrupt(); + if (millis > 0) { + try { + synchronized (STOP_OBJECT) { + STOP_OBJECT.wait(millis); // 1000 milliseconds is one second. + } + } catch (InterruptedException ex) { + LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage())); + } + } + } + + /** + * Use to stop the wait method of a Object. + */ + public static void notifyObject() { + synchronized (STOP_OBJECT) { + STOP_OBJECT.notify(); } } @@ -261,10 +307,20 @@ public class Common { String messageStr = new String(message); if (messageStr.contains(SAFE_MOD)) { LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again")); + if (messageStr.split(":").length == 2) { + iteration = Integer.parseInt(messageStr.split(":")[1]); + } else { + LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration.")); + } return false; } else if (messageStr.contains(JOB_NOT_AVAILABLE)) { LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " + "again")); + if (messageStr.split(":").length == 2) { + iteration = Integer.parseInt(messageStr.split(":")[1]); + } else { + LOGGER.info(Common.addTag("[isSeverReady] the server does not return the current iteration.")); + } return false; } else { return true; @@ -307,7 +363,7 @@ public class Common { * Check whether the path set by user exists. * * @param path the path set by user. - * @return boolean value, true indicates the path is exist, false indicates the path is not exist + * @return boolean value, true indicates the path is exist, false indicates the path does not exist */ public static boolean checkPath(String path) { if (path == null) { @@ -323,7 +379,7 @@ public class Common { LOGGER.info(addTag("[check path " + i + "] " + paths[i])); File file = new File(paths[i]); if (!file.exists()) { - LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check")); + LOGGER.severe(Common.addTag("[checkPath] the path does not exist, please check")); return false; } } @@ -415,4 +471,131 @@ public class Common { throw new IllegalArgumentException(); } } + + /** + * Record the current iteration when server is not ready. + * + * @return int value, the current iteration when server is not ready. + */ + public static int getIteration() { + return iteration; + } + + /** + * Determine whether to conduct https communication according to the domain name set by user. + * + * @return boolean value, true means conducting https communication, false means conducting http communication. + */ + public static boolean isHttps() { + return isHttps; + } + + public static void setIsHttps(String tag) { + if ("https".equals(tag)) { + LOGGER.info(Common.addTag("conducting https communication")); + Common.isHttps = true; + } else if ("http".equals(tag)) { + LOGGER.info(Common.addTag("conducting http communication")); + Common.isHttps = false; + } else { + LOGGER.info(Common.addTag("The domain header set by the user is incorrect, please check")); + throw new IllegalArgumentException(); + } + } + + + /** + * Initialization session. + * + * @return the status code in client. + */ + public static FLClientStatus initSession(String modelPath) { + FLParameter flParameter = FLParameter.getInstance(); + LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + if (Common.checkFLName(flParameter.getFlName())) { + return deprecatedInitSession(); + } + Status tag; + LOGGER.info(Common.addTag("==========Loading model, " + modelPath + " Create " + + " Session=============")); + Client client = ClientManager.getClient(flParameter.getFlName()); + tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig()); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return " + + "is -1")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + /** + * Free session. + */ + protected static void freeSession() { + FLParameter flParameter = FLParameter.getInstance(); + if (Common.checkFLName(flParameter.getFlName())) { + deprecatedFreeSession(); + } else { + LOGGER.info(Common.addTag("===========free session=============")); + Client client = ClientManager.getClient(flParameter.getFlName()); + client.free(); + } + } + + /** + * Initialization session. + * + * @return the status code in client. + */ + private static FLClientStatus deprecatedInitSession() { + FLParameter flParameter = FLParameter.getInstance(); + int tag = 0; + if (flParameter.getFlName().equals(ALBERT)) { + LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " + + "Train Session=============")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); + if (tag == -1) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return " + + "is -1")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " + + "Create inference Session=============")); + AlInferBert alInferBert = AlInferBert.getInstance(); + tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " + + "Train Session=============")); + TrainLenet trainLenet = TrainLenet.getInstance(); + tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); + } + if (tag == -1) { + LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is " + + "-1")); + return FLClientStatus.FAILED; + } + return FLClientStatus.SUCCESS; + } + + /** + * Free session. + */ + private static void deprecatedFreeSession() { + FLParameter flParameter = FLParameter.getInstance(); + if (flParameter.getFlName().equals(ALBERT)) { + LOGGER.info(Common.addTag("===========free train session=============")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + SessionUtil.free(alTrainBert.getTrainSession()); + if (!flParameter.getTestDataset().equals("null")) { + LOGGER.info(Common.addTag("===========free inference session=============")); + AlInferBert alInferBert = AlInferBert.getInstance(); + SessionUtil.free(alInferBert.getTrainSession()); + } + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("===========free session=============")); + TrainLenet trainLenet = TrainLenet.getInstance(); + SessionUtil.free(trainLenet.getTrainSession()); + } + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java index 74809ff11e8..7667d8c0092 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLCommunication.java @@ -17,6 +17,8 @@ package com.mindspore.flclient; import static com.mindspore.flclient.FLParameter.TIME_OUT; +import static com.mindspore.flclient.LocalFLParameter.ANDROID; +import static com.mindspore.flclient.LocalFLParameter.X86; import okhttp3.Call; import okhttp3.Callback; @@ -49,12 +51,16 @@ import javax.net.ssl.X509TrustManager; */ public class FLCommunication implements IFLCommunication { private static int timeOut; - private static boolean ifCertificateVerify = false; + private static String sslProtocol; + private static String env; + private static SSLSocketFactory sslSocketFactory; + private static X509TrustManager x509TrustManager; private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8"); private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString()); private static volatile FLCommunication communication; private FLParameter flParameter = FLParameter.getInstance(); + private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private OkHttpClient client; private FLCommunication() { @@ -63,7 +69,14 @@ public class FLCommunication implements IFLCommunication { } else { timeOut = TIME_OUT; } - ifCertificateVerify = flParameter.isUseSSL(); + sslProtocol = flParameter.getSslProtocol(); + if (!Common.checkFLName(flParameter.getFlName())) { + env = flParameter.getDeployEnv(); + if (ANDROID.equals(env)) { + sslSocketFactory = flParameter.getSslSocketFactory(); + x509TrustManager = flParameter.getX509TrustManager(); + } + } client = getOkHttpClient(); } @@ -85,33 +98,29 @@ public class FLCommunication implements IFLCommunication { } }; final TrustManager[] trustAllCerts = new TrustManager[]{trustManager}; - try { - LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut)); - OkHttpClient.Builder builder = new OkHttpClient.Builder(); - builder.connectTimeout(timeOut, TimeUnit.SECONDS); - builder.writeTimeout(timeOut, TimeUnit.SECONDS); - builder.readTimeout(3 * timeOut, TimeUnit.SECONDS); - if (ifCertificateVerify) { - builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), - SSLSocketFactoryTools.getInstance().getmTrustManager()); - builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier()); - } else { - final SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(null, trustAllCerts, Common.getSecureRandom()); - final SSLSocketFactory sslFactory = sslContext.getSocketFactory(); - builder.sslSocketFactory(sslFactory, trustManager); + LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut)); + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + builder.connectTimeout(timeOut, TimeUnit.SECONDS); + builder.writeTimeout(timeOut, TimeUnit.SECONDS); + builder.readTimeout(3 * timeOut, TimeUnit.SECONDS); + if (Common.isHttps()) { + if (ANDROID.equals(env)) { + builder.sslSocketFactory(sslSocketFactory, x509TrustManager); builder.hostnameVerifier(new HostnameVerifier() { @Override public boolean verify(String arg0, SSLSession arg1) { return true; } }); + } else { + builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), + SSLSocketFactoryTools.getInstance().getmTrustManager()); + builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier()); } - return builder.build(); - } catch (NoSuchAlgorithmException | KeyManagementException ex) { - LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage())); - throw new IllegalArgumentException(ex); + } else { + LOGGER.info(Common.addTag("conducting http communication, do not need SSLSocketFactoryTools")); } + return builder.build(); } /** diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLJobResultCallback.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLJobResultCallback.java index 77d52d116a1..f425d09c643 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLJobResultCallback.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLJobResultCallback.java @@ -33,6 +33,7 @@ public class FLJobResultCallback implements IFLJobResultCallback { * @param iterationSeq Iteration number * @param resultCode Status Code */ + @Override public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) { LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode)); @@ -45,6 +46,7 @@ public class FLJobResultCallback implements IFLJobResultCallback { * @param iterationCount total Iteration numbers * @param resultCode Status Code */ + @Override public void onFlJobFinished(String modelName, int iterationCount, int resultCode) { LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode)); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index aedc1924c47..8f232162c89 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -22,8 +22,16 @@ import static com.mindspore.flclient.LocalFLParameter.LENET; import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlTrainBert; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.CommonUtils; +import com.mindspore.flclient.model.RunType; import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.Status; import com.mindspore.flclient.model.TrainLenet; +import com.mindspore.flclient.pki.PkiBean; +import com.mindspore.flclient.pki.PkiUtil; +import com.mindspore.lite.MSTensor; import mindspore.schema.CipherPublicParams; import mindspore.schema.FLPlan; @@ -36,6 +44,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Date; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.logging.Logger; @@ -54,7 +63,7 @@ public class FLLiteClient { private double dpNormClipAdapt = 0.05d; private FLCommunication flCommunication; private FLClientStatus status; - private int retCode; + private int retCode = ResponseCode.RequestError; private int iterations = 1; private int epochs = 1; private int batchSize = 16; @@ -68,12 +77,15 @@ public class FLLiteClient { private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private SecureProtocol secureProtocol = new SecureProtocol(); private String nextRequestTime; + private Client client; + private Map oldFeatureMap; /** * Defining a constructor of teh class FLLiteClient. */ public FLLiteClient() { flCommunication = FLCommunication.getInstance(); + client = ClientManager.getClient(flParameter.getFlName()); } private int setGlobalParameters(ResponseFLJob flJob) { @@ -87,22 +99,15 @@ public class FLLiteClient { batchSize = flPlan.miniBatch(); String serverMod = flPlan.serverMode(); localFLParameter.setServerMod(serverMod); - LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + serverMod)); - if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[startFLJob] set for AlTrainBert: " + batchSize)); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - alTrainBert.setBatchSize(batchSize); - } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { - LOGGER.info(Common.addTag("[startFLJob] set for TrainLenet: " + batchSize)); - TrainLenet trainLenet = TrainLenet.getInstance(); - trainLenet.setBatchSize(batchSize); + if (Common.checkFLName(flParameter.getFlName())) { + deprecatedSetBatchSize(batchSize); } else { - LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid")); - return -1; + LOGGER.info(Common.addTag("[startFLJob] not set for client: " + batchSize)); } - LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + iterations)); - LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + epochs)); - LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + batchSize)); + LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter from server: " + serverMod)); + LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter from server: " + iterations)); + LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter from server: " + epochs)); + LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter from server: " + batchSize)); CipherPublicParams cipherPublicParams = flPlan.cipher(); if (cipherPublicParams == null) { LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null")); @@ -232,7 +237,12 @@ public class FLLiteClient { StartFLJob startFLJob = StartFLJob.getInstance(); Date date = new Date(); long time = date.getTime(); - byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time); + + PkiBean pkiBean = null; + if (flParameter.isPkiVerify()) { + pkiBean = PkiUtil.genPkiBean(flParameter.getClientID(), time); + } + byte[] msg = startFLJob.getRequestStartFLJob(trainDataSize, iteration, time, pkiBean); try { long start = Common.startTime("single startFLJob"); LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length)); @@ -243,6 +253,7 @@ public class FLLiteClient { status = FLClientStatus.RESTART; Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return status; } LOGGER.info(Common.addTag("[startFLJob] the response message length: " + message.length)); @@ -250,12 +261,9 @@ public class FLLiteClient { ByteBuffer buffer = ByteBuffer.wrap(message); ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(buffer); status = judgeStartFLJob(startFLJob, responseDataBuf); - retCode = responseDataBuf.retcode(); } catch (IOException e) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + - e.getMessage())); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; + failed("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage(), + ResponseCode.RequestError); } return status; } @@ -263,12 +271,13 @@ public class FLLiteClient { private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) { iteration = responseDataBuf.iteration(); FLClientStatus response = startFLJob.doResponse(responseDataBuf); + retCode = startFLJob.getRetCode(); status = response; switch (response) { case SUCCESS: LOGGER.info(Common.addTag("[startFLJob] startFLJob success")); featureSize = startFLJob.getFeatureSize(); - secureProtocol.setEncryptFeatureName(startFLJob.getEncryptFeatureName()); + secureProtocol.setUpdateFeatureName(startFLJob.getUpdateFeatureName()); LOGGER.info(Common.addTag("[startFLJob] ***the feature size get in ResponseFLJob***: " + featureSize)); int tag = setGlobalParameters(responseDataBuf); if (tag == -1) { @@ -297,38 +306,37 @@ public class FLLiteClient { return status; } + private FLClientStatus trainLoop() { + retCode = ResponseCode.SUCCEED; + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + retCode = ResponseCode.SUCCEED; + LOGGER.info(Common.addTag("[train] train in " + flParameter.getFlName())); + Status tag = client.trainModel(epochs); + if (!Status.SUCCESS.equals(tag)) { + failed("[train] unsolved error code in ", ResponseCode.RequestError); + } + client.saveModel(flParameter.getTrainModelPath()); + Common.freeSession(); + return status; + } + /** * Define the training process. * * @return the status code corresponding to the response message. */ public FLClientStatus localTrain() { + LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "====================================")); - status = FLClientStatus.SUCCESS; - retCode = ResponseCode.SUCCEED; - if (flParameter.getFlName().equals(ALBERT)) { - LOGGER.info(Common.addTag("[train] train in albert")); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); - if (tag == -1) { - LOGGER.severe(Common.addTag("[train] unsolved error code in ")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - } - } else if (flParameter.getFlName().equals(LENET)) { - LOGGER.info(Common.addTag("[train] train in lenet")); - TrainLenet trainLenet = TrainLenet.getInstance(); - int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs); - if (tag == -1) { - LOGGER.severe(Common.addTag("[train] unsolved error code in ")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - } + if (Common.checkFLName(flParameter.getFlName())) { + status = deprecatedTrainLoop(); } else { - LOGGER.severe(Common.addTag("[train] the flName is not valid")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; + status = trainLoop(); } return status; } @@ -357,6 +365,7 @@ public class FLLiteClient { status = FLClientStatus.RESTART; Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return status; } LOGGER.info(Common.addTag("[updateModel] the response message length: " + message.length)); @@ -364,16 +373,14 @@ public class FLLiteClient { ByteBuffer debugBuffer = ByteBuffer.wrap(message); ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(debugBuffer); status = updateModelBuf.doResponse(responseDataBuf); - retCode = responseDataBuf.retcode(); + retCode = updateModelBuf.getRetCode(); if (status == FLClientStatus.RESTART) { nextRequestTime = responseDataBuf.nextReqTime(); } LOGGER.info(Common.addTag("[updateModel] get response from server ok!")); } catch (IOException e) { - LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + - e.getMessage())); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; + failed("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage(), + ResponseCode.RequestError); } return status; } @@ -396,6 +403,7 @@ public class FLLiteClient { LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " + "again")); status = FLClientStatus.WAIT; + retCode = ResponseCode.SucNotReady; return status; } LOGGER.info(Common.addTag("[getModel] the response message length: " + message.length)); @@ -404,19 +412,35 @@ public class FLLiteClient { ByteBuffer debugBuffer = ByteBuffer.wrap(message); ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer); status = getModelBuf.doResponse(responseDataBuf); - retCode = responseDataBuf.retcode(); + retCode = getModelBuf.getRetCode(); if (status == FLClientStatus.RESTART) { nextRequestTime = responseDataBuf.timestamp(); } LOGGER.info(Common.addTag("[getModel] get response from server ok!")); } catch (IOException e) { - LOGGER.severe(Common.addTag("[getModel] un sloved error code: catch IOException: " + e.getMessage())); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; + failed("[getModel] un sloved error code: catch IOException: " + e.getMessage(), ResponseCode.RequestError); } return status; } + private Map getFeatureMap() { + Map featureMap = new HashMap<>(); + if (Common.checkFLName(flParameter.getFlName())) { + featureMap = deprecatedGetFeatureMap(); + return featureMap; + } + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + Common.freeSession(); + retCode = ResponseCode.RequestError; + return new HashMap<>(); + } + List features = client.getFeatures(); + featureMap = CommonUtils.convertTensorToFeatures(features); + Common.freeSession(); + return featureMap; + } + /** * Obtain the weight of the model before training. * @@ -456,6 +480,58 @@ public class FLLiteClient { return mapBeforeTrain; } + private void getOldFeatureMap() { + EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); + if (encryptLevel == EncryptLevel.DP_ENCRYPT) { + Map featureMap = getFeatureMap(); + oldFeatureMap = getOldMapCopy(featureMap); + } + } + + public void updateDpNormClip() { + EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); + if (encryptLevel == EncryptLevel.DP_ENCRYPT) { + Map fedFeatureMap = getFeatureMap(); + float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap); + if (fedWeightUpdateNorm == -1) { + LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " + + "-1, please check!")); + throw new IllegalArgumentException(); + } + LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm)); + float newNormCLip = (float) getDpNormClipFactor() * fedWeightUpdateNorm; + if (iteration == 1) { + setDpNormClipAdapt(newNormCLip); + LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); + } else { + if (newNormCLip < getDpNormClipAdapt()) { + setDpNormClipAdapt(newNormCLip); + LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); + } + } + LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + getDpNormClipAdapt())); + } + } + + private float calWeightUpdateNorm(Map originalData, Map newData) { + float updateL2Norm = 0f; + for (String key : originalData.keySet()) { + float[] data = originalData.get(key); + float[] dataAfterUpdate = newData.get(key); + for (int j = 0; j < data.length; j++) { + if (j >= dataAfterUpdate.length) { + LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " + + "please check"); + return -1; + } + float updateData = data[j] - dataAfterUpdate[j]; + updateL2Norm += updateData * updateData; + } + } + updateL2Norm = (float) Math.sqrt(updateL2Norm); + return updateL2Norm; + } + /** * Obtain pairwise mask and individual mask. * @@ -477,16 +553,14 @@ public class FLLiteClient { localFLParameter.getEncryptLevel().toString() + "> : " + curStatus)); return curStatus; case DP_ENCRYPT: - Map map = new HashMap(); - if (flParameter.getFlName().equals(ALBERT)) { - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - } else if (flParameter.getFlName().equals(LENET)) { - TrainLenet trainLenet = TrainLenet.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); + // get the feature map before train + getOldFeatureMap(); + if (oldFeatureMap.isEmpty()) { + LOGGER.severe(Common.addTag("[Encrypt] the return map in getOldFeatureMapis empty ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; } - Map copyMap = getOldMapCopy(map); - curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, copyMap); + curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, oldFeatureMap); retCode = ResponseCode.SUCCEED; if (curStatus != FLClientStatus.SUCCESS) { LOGGER.info(Common.addTag("---Differential privacy init failed---")); @@ -537,85 +611,50 @@ public class FLLiteClient { } } + + private FLClientStatus evaluateLoop() { + client.free(); + status = FLClientStatus.SUCCESS; + retCode = ResponseCode.SUCCEED; + + float acc = 0; + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); + client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig()); + LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath())); + acc = client.evalModel(); + } else { + LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); + client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig()); + LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getTrainModelPath())); + acc = client.evalModel(); + } + if (Float.isNaN(acc)) { + failed("[evaluate] unsolved error code in : the return acc is NAN", ResponseCode.RequestError); + return status; + } + LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); + return status; + } + + private void failed(String log, int retCode) { + LOGGER.severe(Common.addTag(log)); + status = FLClientStatus.FAILED; + this.retCode = retCode; + } + /** * Evaluate model after getting model from server. * * @return the status code in client. */ public FLClientStatus evaluateModel() { - status = FLClientStatus.SUCCESS; - retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("===================================evaluate model after getting model from " + "server===================================")); - if (flParameter.getFlName().equals(ALBERT)) { - float acc = 0; - if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); - AlInferBert alInferBert = AlInferBert.getInstance(); - int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), - flParameter.getIdsFile(), true); - if (dataSize <= 0) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the " + - "return dataSize<=0")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; - } - acc = alInferBert.evalModel(); - } else { - LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), - flParameter.getIdsFile()); - if (dataSize <= 0) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the " + - "return dataSize<=0")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; - } - acc = alTrainBert.evalModel(); - } - if (Float.isNaN(acc)) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; - } - LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + - flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + - " idsFile: " + flParameter.getIdsFile())); - LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); - } else if (flParameter.getFlName().equals(LENET)) { - TrainLenet trainLenet = TrainLenet.getInstance(); - if (flParameter.getTestDataset().split(",").length < 2) { - LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " + - "format of ")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; - } - int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], - flParameter.getTestDataset().split(",")[1]); - if (dataSize <= 0) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return " + - "dataSize<=0")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; - } - float acc = trainLenet.evalModel(); - if (Float.isNaN(acc)) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc" + - " is NAN")); - status = FLClientStatus.FAILED; - retCode = ResponseCode.RequestError; - return status; - } - LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + - flParameter.getTestDataset().split(",")[0] + " labelPath: " + - flParameter.getTestDataset().split(",")[1])); - LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); + if (Common.checkFLName(flParameter.getFlName())) { + status = deprecatedEvaluateLoop(); + } else { + status = evaluateLoop(); } return status; } @@ -623,10 +662,43 @@ public class FLLiteClient { /** * Set date path. * - * @param dataPath, train or test dataset and label set. * @return date size. */ - public int setInput(String dataPath) { + public int setInput() { + int dataSize = 0; + if (Common.checkFLName(flParameter.getFlName())) { + dataSize = deprecatedSetInput(flParameter.getTrainDataset()); + return dataSize; + } + retCode = ResponseCode.SUCCEED; + LOGGER.info(Common.addTag("==========set input===========")); + + // train + dataSize = client.initDataSets(flParameter.getDataMap()).get(RunType.TRAINMODE); + if (dataSize <= 0) { + retCode = ResponseCode.RequestError; + return -1; + } + return dataSize; + } + + private int deprecatedSetBatchSize(int batchSize) { + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] set for AlTrainBert: " + batchSize)); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + alTrainBert.setBatchSize(batchSize); + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] set for TrainLenet: " + batchSize)); + TrainLenet trainLenet = TrainLenet.getInstance(); + trainLenet.setBatchSize(batchSize); + } else { + LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid")); + return -1; + } + return 0; + } + + private int deprecatedSetInput(String dataPath) { retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("==========set input===========")); int dataSize = 0; @@ -653,61 +725,139 @@ public class FLLiteClient { return dataSize; } - /** - * Initialization session. - * - * @return the status code in client. - */ - public FLClientStatus initSession() { - int tag = 0; + private FLClientStatus deprecatedTrainLoop() { + retCode = ResponseCode.SUCCEED; + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + status = FLClientStatus.SUCCESS; retCode = ResponseCode.SUCCEED; if (flParameter.getFlName().equals(ALBERT)) { - LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " + - "Train Session=============")); + LOGGER.info(Common.addTag("[train] train in albert")); AlTrainBert alTrainBert = AlTrainBert.getInstance(); - tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); + int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); if (tag == -1) { - LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return " + - "is -1")); - retCode = ResponseCode.RequestError; - return FLClientStatus.FAILED; + failed("[train] unsolved error code in ", ResponseCode.RequestError); } - LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " + - "Create inference Session=============")); - AlInferBert alInferBert = AlInferBert.getInstance(); - tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); } else if (flParameter.getFlName().equals(LENET)) { - LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " + - "Train Session=============")); + LOGGER.info(Common.addTag("[train] train in lenet")); TrainLenet trainLenet = TrainLenet.getInstance(); - tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); + int tag = trainLenet.trainModel(flParameter.getTrainModelPath(), epochs); + if (tag == -1) { + failed("[train] unsolved error code in ", ResponseCode.RequestError); + } + } else { + failed("[train] the flName is not valid", ResponseCode.RequestError); } - if (tag == -1) { - LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is " + - "-1")); - retCode = ResponseCode.RequestError; - return FLClientStatus.FAILED; - } - return FLClientStatus.SUCCESS; + Common.freeSession(); + return status; } - /** - * Free session. - */ - protected void freeSession() { - if (flParameter.getFlName().equals(ALBERT)) { - LOGGER.info(Common.addTag("===========free train session=============")); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - SessionUtil.free(alTrainBert.getTrainSession()); - if (!flParameter.getTestDataset().equals("null")) { - LOGGER.info(Common.addTag("===========free inference session=============")); - AlInferBert alInferBert = AlInferBert.getInstance(); - SessionUtil.free(alInferBert.getTrainSession()); - } - } else if (flParameter.getFlName().equals(LENET)) { - LOGGER.info(Common.addTag("===========free session=============")); - TrainLenet trainLenet = TrainLenet.getInstance(); - SessionUtil.free(trainLenet.getTrainSession()); + private Map deprecatedGetFeatureMap() { + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + Common.freeSession(); + retCode = ResponseCode.RequestError; + return new HashMap<>(); } + Map featureMap = new HashMap<>(); + if (flParameter.getFlName().equals(ALBERT)) { + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); + } else if (flParameter.getFlName().equals(LENET)) { + TrainLenet trainLenet = TrainLenet.getInstance(); + featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); + } + Common.freeSession(); + return featureMap; + } + + + private FLClientStatus deprecatedEvaluateLoop() { + status = FLClientStatus.SUCCESS; + retCode = ResponseCode.SUCCEED; + if (flParameter.getFlName().equals(ALBERT)) { + float acc = 0; + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); + AlInferBert alInferBert = AlInferBert.getInstance(); + int tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); + if (tag == -1) { + failed("[evaluate] unsolved error code in : the return is -1", + ResponseCode.RequestError); + return FLClientStatus.FAILED; + } + int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), + flParameter.getIdsFile(), true); + if (dataSize <= 0) { + failed("[evaluate] unsolved error code in : the return dataSize<=0", + ResponseCode.RequestError); + return status; + } + acc = alInferBert.evalModel(); + SessionUtil.free(alInferBert.getTrainSession()); + } else { + LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + int tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), false); + if (tag == -1) { + failed("[evaluate] unsolved error code in : the return is -1", + ResponseCode.RequestError); + return FLClientStatus.FAILED; + } + int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), + flParameter.getIdsFile()); + if (dataSize <= 0) { + failed("[evaluate] unsolved error code in : the return dataSize<=0", + ResponseCode.RequestError); + return status; + } + acc = alTrainBert.evalModel(); + SessionUtil.free(alTrainBert.getTrainSession()); + } + if (Float.isNaN(acc)) { + failed("[evaluate] unsolved error code in : the return acc is NAN", + ResponseCode.RequestError); + return status; + } + LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + + " idsFile: " + flParameter.getIdsFile())); + LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); + } else if (flParameter.getFlName().equals(LENET)) { + TrainLenet trainLenet = TrainLenet.getInstance(); + if (flParameter.getTestDataset().split(",").length < 2) { + failed("[evaluate] the set testDataPath for lenet is not valid, should be the format of ", ResponseCode.RequestError); + return status; + } + int tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); + if (tag == -1) { + failed("[evaluate] unsolved error code in : the return is -1", + ResponseCode.RequestError); + return FLClientStatus.FAILED; + } + int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], + flParameter.getTestDataset().split(",")[1]); + if (dataSize <= 0) { + failed("[evaluate] unsolved error code in : the return dataSize<=0", + ResponseCode.RequestError); + return status; + } + float acc = trainLenet.evalModel(); + SessionUtil.free(trainLenet.getTrainSession()); + if (Float.isNaN(acc)) { + failed("[evaluate] unsolved error code in : the return acc is NAN", + ResponseCode.RequestError); + return status; + } + LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + + flParameter.getTestDataset().split(",")[0] + " labelPath: " + + flParameter.getTestDataset().split(",")[1])); + LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); + } + return status; } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java index 36cc1b3b925..1b602e8325d 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java @@ -18,10 +18,19 @@ package com.mindspore.flclient; import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import com.mindspore.flclient.model.RunType; + +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.logging.Logger; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509TrustManager; + /** * Defines global parameters used during federated learning and these parameters are provided for users to set. * @@ -41,21 +50,39 @@ public class FLParameter { public static final int SLEEP_TIME = 1000; private static volatile FLParameter flParameter; + private String deployEnv; private String domainName; private String certPath; + + private SSLSocketFactory sslSocketFactory; + private X509TrustManager x509TrustManager; + private IFLJobResultCallback iflJobResultCallback = new FLJobResultCallback(); + private String trainDataset; private String vocabFile = "null"; private String idsFile = "null"; private String testDataset = "null"; + private boolean useSSL = false; private String flName; private String trainModelPath; private String inferModelPath; + private String sslProtocol = "TLSv1.2"; private String clientID; - private boolean useSSL = false; private int timeOut; private int sleepTime; private boolean ifUseElb = false; private int serverNum = 1; + private boolean ifPkiVerify = false; + private String equipCrlPath = "null"; + private long validIterInterval = 3600000L; + private int threadNum = 1; + private BindMode cpuBindMode = BindMode.NOT_BINDING_CORE; + + private List trainWeightName = new ArrayList<>(); + private List inferWeightName = new ArrayList<>(); + private Map> dataMap = new HashMap<>(); + private ServerMod serverMod; + private int batchSize; private FLParameter() { clientID = UUID.randomUUID().toString(); @@ -79,10 +106,28 @@ public class FLParameter { return localRef; } + public String getDeployEnv() { + if (deployEnv == null || deployEnv.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before using")); + throw new IllegalArgumentException(); + } + return deployEnv; + } + + public void setDeployEnv(String env) { + if (Common.checkEnv(env)) { + this.deployEnv = env; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not in envTrustList: x86, android, " + + "please check it before setting")); + throw new IllegalArgumentException(); + } + } + public String getDomainName() { if (domainName == null || domainName.isEmpty()) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set it " + - "before use")); + "before using")); throw new IllegalArgumentException(); } return domainName; @@ -91,16 +136,33 @@ public class FLParameter { public void setDomainName(String domainName) { if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is not valid, it should be like " + - "as https://...... or http://......, please check it before set")); + "as https://...... or http://......, please check it before setting")); throw new IllegalArgumentException(); } this.domainName = domainName; } + public String getClientID() { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please check")); + throw new IllegalArgumentException(); + } + return clientID; + } + + public void setClientID(String clientID) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please check " + + "before setting")); + throw new IllegalArgumentException(); + } + this.clientID = clientID; + } + public String getCertPath() { if (certPath == null || certPath.isEmpty()) { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set it " + - "before use")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, the " + + "must be set when conducting https communication, please set it by FLParameter.setCertPath()")); throw new IllegalArgumentException(); } return certPath; @@ -111,28 +173,70 @@ public class FLParameter { if (Common.checkPath(realCertPath)) { this.certPath = realCertPath; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it " + - "before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, it must be a valid" + + " path when conducting https communication, please check it before setting")); throw new IllegalArgumentException(); } } + public SSLSocketFactory getSslSocketFactory() { + if (sslSocketFactory == null) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, the " + + " must be set when the deployEnv being \"android\", please set it by " + + "FLParameter.setSslSocketFactory()")); + throw new IllegalArgumentException(); + } + return sslSocketFactory; + } + + public void setSslSocketFactory(SSLSocketFactory sslSocketFactory) { + this.sslSocketFactory = sslSocketFactory; + } + + public X509TrustManager getX509TrustManager() { + if (x509TrustManager == null) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, the " + + " must be set when the deployEnv being \"android\", please set it by " + + "FLParameter.setX509TrustManager()")); + throw new IllegalArgumentException(); + } + return x509TrustManager; + } + + public void setX509TrustManager(X509TrustManager x509TrustManager) { + this.x509TrustManager = x509TrustManager; + } + + public IFLJobResultCallback getIflJobResultCallback() { + if (iflJobResultCallback == null) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it" + + " before using")); + throw new IllegalArgumentException(); + } + return iflJobResultCallback; + } + + public void setIflJobResultCallback(IFLJobResultCallback iflJobResultCallback) { + this.iflJobResultCallback = iflJobResultCallback; + } + public String getTrainDataset() { if (trainDataset == null || trainDataset.isEmpty()) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set " + - "it before use")); + "it before using")); throw new IllegalArgumentException(); } return trainDataset; } public void setTrainDataset(String trainDataset) { + LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED)); String realTrainDataset = Common.getRealPath(trainDataset); if (Common.checkPath(realTrainDataset)) { this.trainDataset = realTrainDataset; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it " + - "before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please check " + + "it before setting")); throw new IllegalArgumentException(); } } @@ -140,38 +244,41 @@ public class FLParameter { public String getVocabFile() { if ("null".equals(vocabFile) && ALBERT.equals(flName)) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before " + - "use")); + "using")); throw new IllegalArgumentException(); } return vocabFile; } public void setVocabFile(String vocabFile) { + LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED)); String realVocabFile = Common.getRealPath(vocabFile); if (Common.checkPath(realVocabFile)) { this.vocabFile = realVocabFile; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it " + - "before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please check it " + + "before setting")); throw new IllegalArgumentException(); } } public String getIdsFile() { if ("null".equals(idsFile) && ALBERT.equals(flName)) { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before " + + "using")); throw new IllegalArgumentException(); } return idsFile; } public void setIdsFile(String idsFile) { + LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED)); String realIdsFile = Common.getRealPath(idsFile); if (Common.checkPath(realIdsFile)) { this.idsFile = realIdsFile; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it " + - "before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please check it " + + "before setting")); throw new IllegalArgumentException(); } } @@ -181,12 +288,13 @@ public class FLParameter { } public void setTestDataset(String testDataset) { + LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED)); String realTestDataset = Common.getRealPath(testDataset); if (Common.checkPath(realTestDataset)) { this.testDataset = realTestDataset; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check it " + - "before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please check it" + + " before setting")); throw new IllegalArgumentException(); } } @@ -194,27 +302,20 @@ public class FLParameter { public String getFlName() { if (flName == null || flName.isEmpty()) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set it " + - "before use")); + "before using")); throw new IllegalArgumentException(); } return flName; } public void setFlName(String flName) { - if (Common.checkFLName(flName)) { - this.flName = flName; - } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not in FL_NAME_TRUST_LIST: " + - Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " + - "set")); - throw new IllegalArgumentException(); - } + this.flName = flName; } public String getTrainModelPath() { if (trainModelPath == null || trainModelPath.isEmpty()) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set" + - " it before use")); + " it before using")); throw new IllegalArgumentException(); } return trainModelPath; @@ -225,8 +326,8 @@ public class FLParameter { if (Common.checkPath(realTrainModelPath)) { this.trainModelPath = realTrainModelPath; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check " + - "it before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please " + + "check it before setting")); throw new IllegalArgumentException(); } } @@ -234,7 +335,7 @@ public class FLParameter { public String getInferModelPath() { if (inferModelPath == null || inferModelPath.isEmpty()) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set" + - " it before use")); + " it before using")); throw new IllegalArgumentException(); } return inferModelPath; @@ -245,8 +346,8 @@ public class FLParameter { if (Common.checkPath(realInferModelPath)) { this.inferModelPath = realInferModelPath; } else { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is not exist, please check " + - "it before set")); + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please check" + + " it before setting")); throw new IllegalArgumentException(); } } @@ -256,9 +357,31 @@ public class FLParameter { } public void setUseSSL(boolean useSSL) { + LOGGER.warning(Common.addTag("Certificate authentication is required for https communication,this parameter " + + "is true by default and no need to set it, " + Common.LOG_DEPRECATED)); this.useSSL = useSSL; } + public String getSslProtocol() { + if (sslProtocol == null || sslProtocol.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set it" + + " before using")); + throw new IllegalArgumentException(); + } + return sslProtocol; + } + + public void setSslProtocol(String sslProtocol) { + if (Common.checkSSLProtocol(sslProtocol)) { + this.sslProtocol = sslProtocol; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not in sslProtocolTrustList " + + ": " + Arrays.toString(Common.SSL_PROTOCOL_TRUST_LIST.toArray(new String[0])) + ", please check " + + "it before setting")); + throw new IllegalArgumentException(); + } + } + public int getTimeOut() { return timeOut; } @@ -286,7 +409,7 @@ public class FLParameter { public int getServerNum() { if (serverNum <= 0) { LOGGER.severe(Common.addTag("[flParameter] the parameter of <= 0, it should be > 0, please " + - "set it before use")); + "set it before using")); throw new IllegalArgumentException(); } return serverNum; @@ -296,11 +419,159 @@ public class FLParameter { this.serverNum = serverNum; } - public String getClientID() { - if (clientID == null || clientID.isEmpty()) { - LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please check")); + public boolean isPkiVerify() { + return ifPkiVerify; + } + + public void setPkiVerify(boolean ifPkiVerify) { + this.ifPkiVerify = ifPkiVerify; + } + + public String getEquipCrlPath() { + if (equipCrlPath == null || equipCrlPath.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null or empty, please set " + + "it before using")); throw new IllegalArgumentException(); } - return clientID; + return equipCrlPath; + } + + /** + * Obtains the Valid Iteration Interval set by a user. + * + * @return the Valid Iteration Interval. + */ + public long getValidInterval() { + if (validIterInterval <= 0) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is not valid, please set " + + "it as larger than 0.")); + throw new IllegalArgumentException(); + } + return validIterInterval; + } + + public void setEquipCrlPath(String certPath) { + String realCertPath = Common.getRealPath(certPath); + if (Common.checkPath(realCertPath)) { + this.equipCrlPath = realCertPath; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of does not exist, please check " + + "it before setting")); + throw new IllegalArgumentException(); + } + } + + /** + * Set the Valid Iteration Interval. + * + * @param validInterval the Valid Iteration Interval. + */ + public void setValidInterval(long validInterval) { + if (validInterval > 0) { + this.validIterInterval = validInterval; + } else { + LOGGER.severe(Common.addTag("[flParameter] the parameter of should be larger than 0, " + + "please set it again.")); + throw new IllegalArgumentException(); + } + } + + public int getThreadNum() { + return threadNum; + } + + public void setThreadNum(int threadNum) { + if (threadNum <= 0) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of <= 0, please check it before " + + "setting")); + throw new IllegalArgumentException(); + } + this.threadNum = threadNum; + } + + public int getCpuBindMode() { + LOGGER.info(Common.addTag("[flParameter] the parameter of is: " + cpuBindMode.toString() + " , " + + "the NOT_BINDING_CORE means that not binding core, BIND_LARGE_CORE means binding the large core, " + + "BIND_MIDDLE_CORE means binding the middle core")); + return cpuBindMode.ordinal(); + } + + public void setCpuBindMode(BindMode cpuBindMode) { + this.cpuBindMode = cpuBindMode; + } + + public void setHybridWeightName(List hybridWeightName, RunType runType) { + if (RunType.TRAINMODE.equals(runType)) { + this.trainWeightName = hybridWeightName; + } else if (RunType.INFERMODE.equals(runType)) { + this.inferWeightName = hybridWeightName; + } else { + LOGGER.severe(Common.addTag("[flParameter] the variable can only be set to " + + "or , please check it")); + throw new IllegalArgumentException(); + } + + } + + public List getHybridWeightName(RunType runType) { + if (RunType.TRAINMODE.equals(runType)) { + if (trainWeightName.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please " + + "set it before use")); + throw new IllegalArgumentException(); + } + return trainWeightName; + } else if (RunType.INFERMODE.equals(runType)) { + if (inferWeightName.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please " + + "set it before use")); + throw new IllegalArgumentException(); + } + return inferWeightName; + } else { + LOGGER.severe(Common.addTag("[flParameter] the variable can only be set to " + + "or , please check it")); + throw new IllegalArgumentException(); + } + + } + + public Map> getDataMap() { + if (dataMap.isEmpty()) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please " + + "set it before use")); + throw new IllegalArgumentException(); + } + return dataMap; + } + + public void setDataMap(Map> dataMap) { + this.dataMap = dataMap; + } + + public ServerMod getServerMod() { + if (serverMod == null) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please " + + "set it before use")); + throw new IllegalArgumentException(); + } + return serverMod; + } + + public void setServerMod(ServerMod serverMod) { + this.serverMod = serverMod; + } + + public int getBatchSize() { + return batchSize; + } + + public void setBatchSize(int batchSize) { + if (batchSize <= 0) { + LOGGER.severe(Common.addTag("[flParameter] the parameter of <= 0, please check it before " + + "setting")); + throw new IllegalArgumentException(); + } + this.batchSize = batchSize; } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java index 495abadf14e..71d6da76815 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java @@ -16,16 +16,17 @@ package com.mindspore.flclient; -import static com.mindspore.flclient.LocalFLParameter.ALBERT; -import static com.mindspore.flclient.LocalFLParameter.LENET; - import com.google.flatbuffers.FlatBufferBuilder; import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlTrainBert; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.RunType; import com.mindspore.flclient.model.SessionUtil; -import com.mindspore.flclient.model.TrainLenet; +import com.mindspore.flclient.model.Status; +import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.FeatureMap; import mindspore.schema.RequestGetModel; import mindspore.schema.ResponseCode; @@ -35,6 +36,9 @@ import java.util.ArrayList; import java.util.Date; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + /** * Define the serialization method, handle the response message returned from server for getModel request. * @@ -50,6 +54,7 @@ public class GetModel { private FLParameter flParameter = FLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); + private int retCode = ResponseCode.RequestError; private GetModel() { } @@ -72,6 +77,10 @@ public class GetModel { return localRef; } + public int getRetCode() { + return retCode; + } + /** * Get a flatBuffer builder of RequestGetModel. * @@ -88,8 +97,8 @@ public class GetModel { return builder.iteration(iteration).flName(name).time().build(); } - - private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) { + private FLClientStatus deprecatedParseResponseAlbert(ResponseGetModel responseDataBuf) { + FLClientStatus status; int fmCount = responseDataBuf.featureMapLength(); if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { LOGGER.info(Common.addTag("[getModel] into ")); @@ -113,6 +122,11 @@ public class GetModel { LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" + " " + feature.dataLength())); } + status = Common.initSession(flParameter.getInferModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } int tag = 0; LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " + "model-----------------")); @@ -131,6 +145,7 @@ public class GetModel { LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); return FLClientStatus.FAILED; } + Common.freeSession(); } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { LOGGER.info(Common.addTag("[getModel] into ")); ArrayList featureMaps = new ArrayList(); @@ -145,6 +160,11 @@ public class GetModel { LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength())); } + status = Common.initSession(flParameter.getInferModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } int tag = 0; LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); AlTrainBert alTrainBert = AlTrainBert.getInstance(); @@ -154,11 +174,13 @@ public class GetModel { LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); return FLClientStatus.FAILED; } + Common.freeSession(); } return FLClientStatus.SUCCESS; } - private FLClientStatus parseResponseLenet(ResponseGetModel responseDataBuf) { + private FLClientStatus deprecatedParseResponseLenet(ResponseGetModel responseDataBuf) { + FLClientStatus status; int fmCount = responseDataBuf.featureMapLength(); ArrayList featureMaps = new ArrayList(); for (int i = 0; i < fmCount; i++) { @@ -172,6 +194,11 @@ public class GetModel { LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength())); } + status = Common.initSession(flParameter.getInferModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } int tag = 0; LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); TrainLenet trainLenet = TrainLenet.getInstance(); @@ -180,6 +207,116 @@ public class GetModel { LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); return FLClientStatus.FAILED; } + Common.freeSession(); + return FLClientStatus.SUCCESS; + } + + private FLClientStatus deprecatedParseFeatures(ResponseGetModel responseDataBuf) { + FLClientStatus status = FLClientStatus.SUCCESS; + if (ALBERT.equals(flParameter.getFlName())) { + LOGGER.info(Common.addTag("[getModel] into ")); + status = deprecatedParseResponseAlbert(responseDataBuf); + } else if (LENET.equals(flParameter.getFlName())) { + LOGGER.info(Common.addTag("[getModel] into ")); + status = deprecatedParseResponseLenet(responseDataBuf); + } else { + LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert")); + status = FLClientStatus.FAILED; + } + return status; + } + + private FLClientStatus parseResponseFeatures(ResponseGetModel responseDataBuf) { + FLClientStatus status; + Client client = ClientManager.getClient(flParameter.getFlName()); + int fmCount = responseDataBuf.featureMapLength(); + if (fmCount <= 0) { + LOGGER.severe(Common.addTag("[getModel] the feature size get from server is zero")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod())); + ArrayList trainFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = responseDataBuf.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + if (flParameter.getHybridWeightName(RunType.TRAINMODE).contains(featureName)) { + trainFeatureMaps.add(feature); + LOGGER.info(Common.addTag("[getModel] trainWeightFullname: " + feature.weightFullname() + ", " + + "trainWeightLength: " + feature.dataLength())); + } + if (flParameter.getHybridWeightName(RunType.INFERMODE).contains(featureName)) { + inferFeatureMaps.add(feature); + LOGGER.info(Common.addTag("[getModel] inferWeightFullname: " + feature.weightFullname() + ", " + + "inferWeightLength: " + feature.dataLength())); + } + } + Status tag; + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " + + "model-----------------")); + status = Common.initSession(flParameter.getInferModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + tag = client.updateFeatures(flParameter.getInferModelPath(), inferFeatureMaps); + Common.freeSession(); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------")); + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + tag = client.updateFeatures(flParameter.getTrainModelPath(), inferFeatureMaps); + Common.freeSession(); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[getModel] parseResponseFeatures by " + localFLParameter.getServerMod())); + ArrayList featureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = responseDataBuf.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + featureMaps.add(feature); + LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", " + + "weightLength: " + feature.dataLength())); + } + Status tag; + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps); + LOGGER.info(Common.addTag("[getModel] ===========free session=============")); + Common.freeSession(); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + } return FLClientStatus.SUCCESS; } @@ -190,25 +327,21 @@ public class GetModel { * @return the status code corresponding to the response message. */ public FLClientStatus doResponse(ResponseGetModel responseDataBuf) { - LOGGER.info(Common.addTag("[getModel] ==========get model content is:================")); - LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode())); + retCode = responseDataBuf.retcode(); + LOGGER.info(Common.addTag("[getModel] ==========the response message of getModel is:================")); + LOGGER.info(Common.addTag("[getModel] ==========retCode: " + retCode)); LOGGER.info(Common.addTag("[getModel] ==========reason: " + responseDataBuf.reason())); LOGGER.info(Common.addTag("[getModel] ==========iteration: " + responseDataBuf.iteration())); LOGGER.info(Common.addTag("[getModel] ==========time: " + responseDataBuf.timestamp())); FLClientStatus status = FLClientStatus.SUCCESS; - int retCode = responseDataBuf.retcode(); - switch (retCode) { + switch (responseDataBuf.retcode()) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[getModel] getModel response success")); - if (ALBERT.equals(flParameter.getFlName())) { - LOGGER.info(Common.addTag("[getModel] into ")); - status = parseResponseAlbert(responseDataBuf); - } else if (LENET.equals(flParameter.getFlName())) { - LOGGER.info(Common.addTag("[getModel] into ")); - status = parseResponseLenet(responseDataBuf); + if (Common.checkFLName(flParameter.getFlName())) { + status = deprecatedParseFeatures(responseDataBuf); } else { - LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert")); - throw new IllegalArgumentException(); + LOGGER.info(Common.addTag("[getModel] into ")); + status = parseResponseFeatures(responseDataBuf); } return status; case (ResponseCode.SucNotReady): diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java index f1f30c899c4..59a6cbc4d3b 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java @@ -16,6 +16,8 @@ package com.mindspore.flclient; +import com.mindspore.lite.config.MSConfig; + import org.bouncycastle.math.ec.rfc7748.X25519; import java.util.ArrayList; @@ -59,6 +61,16 @@ public class LocalFLParameter { * The model name supported by federated learning tasks: "albert". */ public static final String ALBERT = "albert"; + + /** + * The deployment environment supported by federated learning tasks: "android". + */ + public static final String ANDROID = "android"; + + /** + * The deployment environment supported by federated learning tasks: "x86". + */ + public static final String X86 = "x86"; private static volatile LocalFLParameter localFLParameter; private List classifierWeightName = new ArrayList<>(); @@ -67,6 +79,9 @@ public class LocalFLParameter { private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString(); private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString(); private String serverMod = ServerMod.HYBRID_TRAINING.toString(); + private boolean stopJobFlag = false; + private MSConfig msConfig = new MSConfig(); + private boolean useSSL = true; private LocalFLParameter() { // set classifierWeightName albertWeightName @@ -143,14 +158,14 @@ public class LocalFLParameter { public void setEncryptLevel(String encryptLevel) { if (encryptLevel == null || encryptLevel.isEmpty()) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is null, please check it " + - "before set")); + "before setting")); throw new IllegalArgumentException(); } if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) && (!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) && (!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is " + encryptLevel + " ," + - " it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set")); + " it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before setting")); throw new IllegalArgumentException(); } this.encryptLevel = encryptLevel; @@ -163,7 +178,7 @@ public class LocalFLParameter { public void setEarlyStopMod(String earlyStopMod) { if (earlyStopMod == null || earlyStopMod.isEmpty()) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is null, please check it " + - "before set")); + "before setting")); throw new IllegalArgumentException(); } if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) && @@ -171,7 +186,8 @@ public class LocalFLParameter { (!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) && (!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is " + earlyStopMod + " ," + - " it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set")); + " it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before " + + "setting")); throw new IllegalArgumentException(); } this.earlyStopMod = earlyStopMod; @@ -184,15 +200,44 @@ public class LocalFLParameter { public void setServerMod(String serverMod) { if (serverMod == null || serverMod.isEmpty()) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is null, please check it " + - "before set")); + "before setting")); throw new IllegalArgumentException(); } if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) && (!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) { LOGGER.severe(Common.addTag("[localFLParameter] the parameter of is " + serverMod + " , it " + - "must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set")); + "must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before setting")); throw new IllegalArgumentException(); } this.serverMod = serverMod; } + + public boolean isStopJobFlag() { + return stopJobFlag; + } + + public void setStopJobFlag(boolean stopJobFlag) { + this.stopJobFlag = stopJobFlag; + } + + public MSConfig getMsConfig() { + return msConfig; + } + + public void setMsConfig(int DeviceType, int threadNum, int cpuBindMode, boolean enable_fp16) { + // arg 0: DeviceType:DT_CPU -> 0 + // arg 1: ThreadNum -> 2 + // arg 2: cpuBindMode:NO_BIND -> 0 + // arg 3: enable_fp16 -> false + msConfig.init(DeviceType, threadNum, cpuBindMode, enable_fp16); + } + + + public boolean isUseSSL() { + return useSSL; + } + + public void setUseSSL(boolean useSSL) { + this.useSSL = useSSL; + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SSLSocketFactoryTools.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SSLSocketFactoryTools.java index 5096b732772..aa853e7aaef 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SSLSocketFactoryTools.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SSLSocketFactoryTools.java @@ -68,7 +68,7 @@ public class SSLSocketFactoryTools { String domainName = flParameter.getDomainName(); if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) { LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the is null or not valid, it should" + - " be like as https://...... , please check!")); + " be like https://...... , please check!")); throw new IllegalArgumentException(); } if (domainName.split("//")[1].split(":").length < 2) { @@ -87,7 +87,7 @@ public class SSLSocketFactoryTools { private void initSslSocketFactory() { try { - sslContext = SSLContext.getInstance("TLS"); + sslContext = SSLContext.getInstance(flParameter.getSslProtocol()); x509Certificate = readCert(flParameter.getCertPath()); myTrustManager = new MyTrustManager(x509Certificate); sslContext.init(null, new TrustManager[]{ @@ -137,9 +137,8 @@ public class SSLSocketFactoryTools { "convert to X509Certificate")); } } catch (FileNotFoundException | CertificateException ex) { - LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " + - "when creating " + - "CertificateFactory in readCert: " + ex.getMessage())); + LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch exception when creating CertificateFactory in " + + "readCert: invalid file or CertificateException")); } finally { try { if (inputStream != null) { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java index a3c6509d778..38c44ad4d84 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java @@ -16,20 +16,12 @@ package com.mindspore.flclient; -import static com.mindspore.flclient.LocalFLParameter.ALBERT; -import static com.mindspore.flclient.LocalFLParameter.LENET; - import com.google.flatbuffers.FlatBufferBuilder; -import com.mindspore.flclient.model.AlTrainBert; -import com.mindspore.flclient.model.SessionUtil; -import com.mindspore.flclient.model.TrainLenet; - import mindspore.schema.FeatureMap; import java.security.SecureRandom; import java.util.ArrayList; -import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; @@ -52,7 +44,7 @@ public class SecureProtocol { private double dpEps; private double dpDelta; private double dpNormClip; - private ArrayList encryptFeatureName = new ArrayList(); + private ArrayList updateFeatureName = new ArrayList(); private int retCode; /** @@ -115,17 +107,17 @@ public class SecureProtocol { * * @return the feature names that needed to be encrypted. */ - public ArrayList getEncryptFeatureName() { - return encryptFeatureName; + public ArrayList getUpdateFeatureName() { + return updateFeatureName; } /** - * Set the parameter encryptFeatureName. + * Set the parameter updateFeatureName. * - * @param encryptFeatureName the feature names that needed to be encrypted. + * @param updateFeatureName the feature names that needed to be encrypted. */ - public void setEncryptFeatureName(ArrayList encryptFeatureName) { - this.encryptFeatureName = encryptFeatureName; + public void setUpdateFeatureName(ArrayList updateFeatureName) { + this.updateFeatureName = updateFeatureName; } /** @@ -146,6 +138,10 @@ public class SecureProtocol { LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============", localFLParameter.getFlID())); // round 0 + if (localFLParameter.isStopJobFlag()) { + LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop")); + return status; + } status = cipherClient.exchangeKeys(); retCode = cipherClient.getRetCode(); LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ", @@ -154,6 +150,10 @@ public class SecureProtocol { return status; } // round 1 + if (localFLParameter.isStopJobFlag()) { + LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop")); + return status; + } status = cipherClient.shareSecrets(); retCode = cipherClient.getRetCode(); LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ", @@ -162,6 +162,10 @@ public class SecureProtocol { return status; } // round2 + if (localFLParameter.isStopJobFlag()) { + LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop")); + return status; + } featureMask = cipherClient.doubleMaskingWeight(); if (featureMask == null || featureMask.length <= 0) { LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" + @@ -180,30 +184,18 @@ public class SecureProtocol { * @param trainDataSize trainDataSize tne size of train data set. * @return the serialized model weights after adding masks. */ - public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) { + public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize, Map trainedMap) { if (featureMask == null || featureMask.length == 0) { LOGGER.severe("[Encrypt] feature mask is null, please check"); return new int[0]; } LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length)); - // get feature map - Map map = new HashMap(); - if (flParameter.getFlName().equals(ALBERT)) { - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - } else if (flParameter.getFlName().equals(LENET)) { - TrainLenet trainLenet = TrainLenet.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); - } else { - LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert")); - throw new IllegalArgumentException(); - } - int featureSize = encryptFeatureName.size(); + int featureSize = updateFeatureName.size(); int[] featuresMap = new int[featureSize]; int maskIndex = 0; for (int i = 0; i < featureSize; i++) { - String key = encryptFeatureName.get(i); - float[] data = map.get(key); + String key = updateFeatureName.get(i); + float[] data = trainedMap.get(key); LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length)); for (int j = 0; j < data.length; j++) { float rawData = data[j]; @@ -349,21 +341,10 @@ public class SecureProtocol { * @param trainDataSize tne size of train data set. * @return the serialized model weights after adding masks. */ - public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) { + public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize, Map trainedMap) { // get feature map - Map map = new HashMap(); - if (flParameter.getFlName().equals(ALBERT)) { - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - } else if (flParameter.getFlName().equals(LENET)) { - TrainLenet trainLenet = TrainLenet.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); - } else { - LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert")); - throw new IllegalArgumentException(); - } Map mapBeforeTrain = modelMap; - int featureSize = encryptFeatureName.size(); + int featureSize = updateFeatureName.size(); // calculate sigma double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta); LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "=============")); @@ -371,8 +352,8 @@ public class SecureProtocol { // calculate l2-norm of all layers' update array double updateL2Norm = 0d; for (int i = 0; i < featureSize; i++) { - String key = encryptFeatureName.get(i); - float[] data = map.get(key); + String key = updateFeatureName.get(i); + float[] data = trainedMap.get(key); float[] dataBeforeTrain = mapBeforeTrain.get(key); for (int j = 0; j < data.length; j++) { float rawData = data[j]; @@ -391,12 +372,12 @@ public class SecureProtocol { // clip and add noise int[] featuresMap = new int[featureSize]; for (int i = 0; i < featureSize; i++) { - String key = encryptFeatureName.get(i); - if (!map.containsKey(key)) { + String key = updateFeatureName.get(i); + if (!trainedMap.containsKey(key)) { LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!"); return new int[0]; } - float[] data = map.get(key); + float[] data = trainedMap.get(key); float[] data2 = new float[data.length]; if (!mapBeforeTrain.containsKey(key)) { LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!"); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java index 1333d7def24..c1d4020ae3b 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java @@ -1,29 +1,21 @@ /* * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ package com.mindspore.flclient; -import static com.mindspore.flclient.LocalFLParameter.ALBERT; -import static com.mindspore.flclient.LocalFLParameter.LENET; - import com.google.flatbuffers.FlatBufferBuilder; import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlTrainBert; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.RunType; import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.Status; import com.mindspore.flclient.model.TrainLenet; +import com.mindspore.flclient.pki.PkiBean; +import com.mindspore.flclient.pki.PkiUtil; import mindspore.schema.FLPlan; import mindspore.schema.FeatureMap; @@ -31,9 +23,14 @@ import mindspore.schema.RequestFLJob; import mindspore.schema.ResponseCode; import mindspore.schema.ResponseFLJob; +import java.io.IOException; +import java.security.cert.Certificate; import java.util.ArrayList; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + /** * StartFLJob * @@ -52,11 +49,425 @@ public class StartFLJob { private int featureSize; private String nextRequestTime; - private ArrayList encryptFeatureName = new ArrayList(); + private ArrayList updateFeatureName = new ArrayList(); + private int retCode = ResponseCode.RequestError; + private float lr = (float) 0.1; + private int batchSize; private StartFLJob() { } + /** + * getInstance of StartFLJob + * + * @return StartFLJob instance + */ + public static StartFLJob getInstance() { + StartFLJob localRef = startFLJob; + if (localRef == null) { + synchronized (StartFLJob.class) { + localRef = startFLJob; + if (localRef == null) { + startFLJob = localRef = new StartFLJob(); + } + } + } + return localRef; + } + + public String getNextRequestTime() { + return nextRequestTime; + } + + public int getRetCode() { + return retCode; + } + + /** + * get request start FLJob + * + * @param dataSize dataSize + * @param iteration iteration + * @param time time + * @param pkiBean pki bean + * @return byte[] data + */ + public byte[] getRequestStartFLJob(int dataSize, int iteration, long time, PkiBean pkiBean) { + RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder(); + + if (flParameter.isPkiVerify()) { + if (pkiBean == null) { + LOGGER.severe(Common.addTag("[startFLJob] the parameter of is null, please check!")); + throw new IllegalArgumentException(); + } + return builder.flName(flParameter.getFlName()) + .time(time) + .id(localFLParameter.getFlID()) + .dataSize(dataSize) + .iteration(iteration) + .signData(pkiBean.getSignData()) + .certificateChain(pkiBean.getCertificates()) + .build(); + } + return builder.flName(flParameter.getFlName()) + .time(time) + .id(localFLParameter.getFlID()) + .dataSize(dataSize) + .iteration(iteration) + .build(); + } + + public int getFeatureSize() { + return featureSize; + } + + public ArrayList getUpdateFeatureName() { + return updateFeatureName; + } + + private FLClientStatus deprecatedParseResponseAlbert(ResponseFLJob flJob) { + FLClientStatus status; + int fmCount = flJob.featureMapLength(); + updateFeatureName.clear(); + if (fmCount <= 0) { + LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); + return FLClientStatus.FAILED; + } + + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod())); + ArrayList albertFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + if (localFLParameter.getAlbertWeightName().contains(featureName)) { + albertFeatureMaps.add(feature); + inferFeatureMaps.add(feature); + featureSize += feature.dataLength(); + updateFeatureName.add(feature.weightFullname()); + } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { + inferFeatureMaps.add(feature); + } else { + continue; + } + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " + + "weightLength: " + feature.dataLength())); + } + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " + + "model-----------------")); + AlInferBert alInferBert = AlInferBert.getInstance(); + tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), + inferFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), + albertFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + Common.freeSession(); + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod())); + ArrayList featureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + featureMaps.add(feature); + featureSize += feature.dataLength(); + updateFeatureName.add(featureName); + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " + + "weightLength: " + feature.dataLength())); + } + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), + featureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + Common.freeSession(); + } + return FLClientStatus.SUCCESS; + } + + private FLClientStatus deprecatedParseResponseLenet(ResponseFLJob flJob) { + FLClientStatus status; + int fmCount = flJob.featureMapLength(); + ArrayList featureMaps = new ArrayList(); + updateFeatureName.clear(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + featureMaps.add(feature); + featureSize += feature.dataLength(); + updateFeatureName.add(featureName); + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); + TrainLenet trainLenet = TrainLenet.getInstance(); + tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + Common.freeSession(); + return FLClientStatus.SUCCESS; + } + + + private FLClientStatus hybridFeatures(ResponseFLJob flJob) { + FLClientStatus status; + Client client = ClientManager.getClient(flParameter.getFlName()); + int fmCount = flJob.featureMapLength(); + ArrayList trainFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + if (flParameter.getHybridWeightName(RunType.TRAINMODE).contains(featureName)) { + trainFeatureMaps.add(feature); + featureSize += feature.dataLength(); + updateFeatureName.add(feature.weightFullname()); + LOGGER.info(Common.addTag("[startFLJob] trainWeightFullname: " + feature.weightFullname() + ", " + + "trainWeightLength: " + feature.dataLength())); + } + if (flParameter.getHybridWeightName(RunType.INFERMODE).contains(featureName)) { + inferFeatureMaps.add(feature); + LOGGER.info(Common.addTag("[startFLJob] inferWeightFullname: " + feature.weightFullname() + ", " + + "inferWeightLength: " + feature.dataLength())); + } + } + Status tag; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " + + "model-----------------")); + status = Common.initSession(flParameter.getInferModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + tag = client.updateFeatures(flParameter.getInferModelPath(), inferFeatureMaps); + Common.freeSession(); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + LOGGER.info(Common.addTag("[startFLJob] set for client: " + lr)); + tag = client.setLearningRate(lr); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[startFLJob] set for client: " + batchSize)); + client.setBatchSize(batchSize); + tag = client.updateFeatures(flParameter.getTrainModelPath(), inferFeatureMaps); + Common.freeSession(); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + return status; + } + + private FLClientStatus normalFeatures(ResponseFLJob flJob) { + FLClientStatus status; + Client client = ClientManager.getClient(flParameter.getFlName()); + int fmCount = flJob.featureMapLength(); + ArrayList featureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + if (feature == null) { + LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + String featureName = feature.weightFullname(); + featureMaps.add(feature); + featureSize += feature.dataLength(); + updateFeatureName.add(featureName); + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " + + "weightLength: " + feature.dataLength())); + } + Status tag; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + return status; + } + LOGGER.info(Common.addTag("[startFLJob] set for client: " + lr)); + tag = client.setLearningRate(lr); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[startFLJob] set for client: " + batchSize)); + client.setBatchSize(batchSize); + tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps); + LOGGER.info(Common.addTag("[startFLJob] ===========free session=============")); + Common.freeSession(); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + return status; + } + + private FLClientStatus parseResponseFeatures(ResponseFLJob flJob) { + FLClientStatus status; + int fmCount = flJob.featureMapLength(); + updateFeatureName.clear(); + + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] parseResponseFeatures by " + localFLParameter.getServerMod())); + status = hybridFeatures(flJob); + if (status == FLClientStatus.FAILED) { + return status; + } + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] parseResponseFeatures by " + localFLParameter.getServerMod())); + status = normalFeatures(flJob); + if (status == FLClientStatus.FAILED) { + return status; + } + } + return FLClientStatus.SUCCESS; + } + + private FLClientStatus deprecatedParseFeatures(ResponseFLJob flJob) { + FLClientStatus status = FLClientStatus.SUCCESS; + if (ALBERT.equals(flParameter.getFlName())) { + LOGGER.info(Common.addTag("[startFLJob] into ")); + status = deprecatedParseResponseAlbert(flJob); + } + if (LENET.equals(flParameter.getFlName())) { + LOGGER.info(Common.addTag("[startFLJob] into ")); + status = deprecatedParseResponseLenet(flJob); + } + return status; + } + + + /** + * response res + * + * @param flJob ResponseFLJob + * @return FLClientStatus + */ + public FLClientStatus doResponse(ResponseFLJob flJob) { + if (flJob == null) { + LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + FLPlan flPlanConfig = flJob.flPlanConfig(); + if (flPlanConfig == null) { + LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + + if (flJob.featureMapLength() <= 0) { + LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); + retCode = ResponseCode.SystemError; + return FLClientStatus.FAILED; + } + + retCode = flJob.retcode(); + LOGGER.info(Common.addTag("[startFLJob] ==========the response message of startFLJob is:================")); + LOGGER.info(Common.addTag("[startFLJob] return retCode: " + retCode)); + LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason())); + LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration())); + LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected())); + LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime())); + nextRequestTime = flJob.nextReqTime(); + LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp())); + FLClientStatus status; + int responseRetCode = flJob.retcode(); + + switch (responseRetCode) { + case (ResponseCode.SUCCEED): + localFLParameter.setServerMod(flPlanConfig.serverMode()); + if (flPlanConfig.lr() != 0) { + lr = flPlanConfig.lr(); + } else { + LOGGER.info(Common.addTag("[startFLJob] the GlobalParameter from server: " + lr + " is not " + + "valid, " + + "will use the default value 0.1")); + } + batchSize = flPlanConfig.miniBatch(); + if (Common.checkFLName(flParameter.getFlName())) { + status = deprecatedParseFeatures(flJob); + } else { + LOGGER.info(Common.addTag("[startFLJob] into ")); + status = parseResponseFeatures(flJob); + } + return status; + case (ResponseCode.OutOfTime): + return FLClientStatus.RESTART; + case (ResponseCode.RequestError): + case (ResponseCode.SystemError): + LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError")); + return FLClientStatus.FAILED; + default: + LOGGER.severe(Common.addTag("[startFLJob] the return from server is invalid: " + retCode)); + return FLClientStatus.FAILED; + } + } + class RequestStartFLJobBuilder { private RequestFLJob requestFLJob; private FlatBufferBuilder builder; @@ -65,6 +476,11 @@ public class StartFLJob { private int dataSize = 0; private int timestampOffset = 0; private int idOffset = 0; + private int signDataOffset = 0; + private int keyAttestationOffset = 0; + private int equipCertOffset = 0; + private int equipCACertOffset = 0; + private int rootCertOffset = 0; public RequestStartFLJobBuilder() { builder = new FlatBufferBuilder(); @@ -135,6 +551,50 @@ public class StartFLJob { return this; } + /** + * signData + * + * @param signData byte[] + * @return RequestStartFLJobBuilder + */ + public RequestStartFLJobBuilder signData(byte[] signData) { + if (signData == null || signData.length == 0) { + LOGGER.severe( + Common.addTag("[startFLJob] the parameter of is null or empty, please check!")); + throw new IllegalArgumentException(); + } + this.signDataOffset = RequestFLJob.createSignDataVector(builder, signData); + return this; + } + + /** + * set certificateChain + * + * @param certificates Certificate array + * @return RequestStartFLJobBuilder + */ + public RequestStartFLJobBuilder certificateChain(Certificate[] certificates) { + if (certificates == null || certificates.length < 4) { + LOGGER.severe(Common.addTag("[startFLJob] the parameter of is null or the length " + + "is not valid (should be >= 4), please check!")); + throw new IllegalArgumentException(); + } + try { + String keyAttestationPem = PkiUtil.getPemFormat(certificates[0]); + String equipCertPem = PkiUtil.getPemFormat(certificates[1]); + String equipCACertPem = PkiUtil.getPemFormat(certificates[2]); + String rootCertPem = PkiUtil.getPemFormat(certificates[3]); + + this.keyAttestationOffset = this.builder.createString(keyAttestationPem); + this.equipCertOffset = this.builder.createString(equipCertPem); + this.equipCACertOffset = this.builder.createString(equipCACertPem); + this.rootCertOffset = this.builder.createString(rootCertPem); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[StartFLJob] catch IOException in certificateChain: " + e.getMessage())); + } + return this; + } + /** * build protobuffer * @@ -152,210 +612,4 @@ public class StartFLJob { return builder.sizedByteArray(); } } - - /** - * getInstance of StartFLJob - * - * @return StartFLJob instance - */ - public static StartFLJob getInstance() { - StartFLJob localRef = startFLJob; - if (localRef == null) { - synchronized (StartFLJob.class) { - localRef = startFLJob; - if (localRef == null) { - startFLJob = localRef = new StartFLJob(); - } - } - } - return localRef; - } - - public String getNextRequestTime() { - return nextRequestTime; - } - - /** - * get request start FLJob - * - * @param dataSize dataSize - * @param iteration iteration - * @param time time - * @return byte[] data - */ - public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) { - RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder(); - return builder.flName(flParameter.getFlName()) - .time(time) - .id(localFLParameter.getFlID()) - .dataSize(dataSize) - .iteration(iteration) - .build(); - } - - public int getFeatureSize() { - return featureSize; - } - - public ArrayList getEncryptFeatureName() { - return encryptFeatureName; - } - - - private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) { - int fmCount = flJob.featureMapLength(); - encryptFeatureName.clear(); - if (fmCount <= 0) { - LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); - return FLClientStatus.FAILED; - } - - if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod())); - ArrayList albertFeatureMaps = new ArrayList(); - ArrayList inferFeatureMaps = new ArrayList(); - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); - if (feature == null) { - LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); - return FLClientStatus.FAILED; - } - String featureName = feature.weightFullname(); - if (localFLParameter.getAlbertWeightName().contains(featureName)) { - albertFeatureMaps.add(feature); - inferFeatureMaps.add(feature); - featureSize += feature.dataLength(); - encryptFeatureName.add(feature.weightFullname()); - } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { - inferFeatureMaps.add(feature); - } else { - continue; - } - LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " + - "weightLength: " + feature.dataLength())); - } - int tag = 0; - LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " + - "model-----------------")); - AlInferBert alInferBert = AlInferBert.getInstance(); - tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), - inferFeatureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); - return FLClientStatus.FAILED; - } - LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), - albertFeatureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); - return FLClientStatus.FAILED; - } - } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { - LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod())); - ArrayList featureMaps = new ArrayList(); - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); - if (feature == null) { - LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); - return FLClientStatus.FAILED; - } - String featureName = feature.weightFullname(); - featureMaps.add(feature); - featureSize += feature.dataLength(); - encryptFeatureName.add(featureName); - LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " + - "weightLength: " + feature.dataLength())); - } - int tag = 0; - LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), - featureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); - return FLClientStatus.FAILED; - } - } - return FLClientStatus.SUCCESS; - } - - private FLClientStatus parseResponseLenet(ResponseFLJob flJob) { - int fmCount = flJob.featureMapLength(); - ArrayList featureMaps = new ArrayList(); - encryptFeatureName.clear(); - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); - if (feature == null) { - LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null")); - return FLClientStatus.FAILED; - } - String featureName = feature.weightFullname(); - featureMaps.add(feature); - featureSize += feature.dataLength(); - encryptFeatureName.add(featureName); - LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + - feature.weightFullname() + ", weightLength: " + feature.dataLength())); - } - int tag = 0; - LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); - TrainLenet trainLenet = TrainLenet.getInstance(); - tag = SessionUtil.updateFeatures(trainLenet.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); - return FLClientStatus.FAILED; - } - return FLClientStatus.SUCCESS; - } - - /** - * response res - * - * @param flJob ResponseFLJob - * @return FLClientStatus - */ - public FLClientStatus doResponse(ResponseFLJob flJob) { - if (flJob == null) { - LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null")); - return FLClientStatus.FAILED; - } - FLPlan flPlanConfig = flJob.flPlanConfig(); - if (flPlanConfig == null) { - LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null")); - return FLClientStatus.FAILED; - } - LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode())); - LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason())); - LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration())); - LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected())); - LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime())); - nextRequestTime = flJob.nextReqTime(); - LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp())); - FLClientStatus status = FLClientStatus.SUCCESS; - int retCode = flJob.retcode(); - - switch (retCode) { - case (ResponseCode.SUCCEED): - localFLParameter.setServerMod(flPlanConfig.serverMode()); - if (ALBERT.equals(flParameter.getFlName())) { - LOGGER.info(Common.addTag("[startFLJob] into ")); - status = parseResponseAlbert(flJob); - } - if (LENET.equals(flParameter.getFlName())) { - LOGGER.info(Common.addTag("[startFLJob] into ")); - status = parseResponseLenet(flJob); - } - return status; - case (ResponseCode.OutOfTime): - return FLClientStatus.RESTART; - case (ResponseCode.RequestError): - case (ResponseCode.SystemError): - LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError")); - return FLClientStatus.FAILED; - default: - LOGGER.severe(Common.addTag("[startFLJob] the return from server is invalid: " + retCode)); - return FLClientStatus.FAILED; - } - } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java index 22541977773..32ff4d6bb62 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -18,19 +18,26 @@ package com.mindspore.flclient; import static com.mindspore.flclient.FLParameter.SLEEP_TIME; import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.ANDROID; import static com.mindspore.flclient.LocalFLParameter.LENET; import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlTrainBert; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.RunType; import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.Status; import com.mindspore.flclient.model.TrainLenet; - +import com.mindspore.flclient.pki.PkiUtil; +import com.mindspore.lite.config.CpuBindMode; import mindspore.schema.ResponseGetModel; import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.logging.Logger; @@ -46,8 +53,38 @@ public class SyncFLJob { private FLParameter flParameter = FLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); - private FLJobResultCallback flJobResultCallback = new FLJobResultCallback(); - private Map oldFeatureMap; + private IFLJobResultCallback flJobResultCallback; + private FLClientStatus curStatus; + + private void initFlIDForPkiVerify() { + if (flParameter.isPkiVerify()) { + LOGGER.info(Common.addTag("pkiVerify mode is open!")); + String equipCertHash = PkiUtil.genEquipCertHash(flParameter.getClientID()); + if (equipCertHash == null || equipCertHash.isEmpty()) { + LOGGER.severe(Common.addTag("equipCertHash is empty, please check your mobile phone, only Huawei " + + "phones are supported now.")); + throw new IllegalArgumentException(); + } + LOGGER.info(Common.addTag("flID for pki verify is: " + equipCertHash)); + localFLParameter.setFlID(equipCertHash); + } else { + LOGGER.info(Common.addTag("pkiVerify mode is not open!")); + localFLParameter.setFlID(flParameter.getClientID()); + } + } + + public SyncFLJob() { + if (!Common.checkFLName(flParameter.getFlName())) { + try { + LOGGER.info(Common.addTag("the flName: " + flParameter.getFlName())); + Class.forName(flParameter.getFlName()); + } catch (ClassNotFoundException e) { + LOGGER.severe(Common.addTag("catch ClassNotFoundException error, the set flName does not exist, please " + + "check: " + e.getMessage())); + throw new IllegalArgumentException(); + } + } + } /** * Starts a federated learning task on the device. @@ -55,209 +92,172 @@ public class SyncFLJob { * @return the status code corresponding to the response message. */ public FLClientStatus flJobRun() { - Common.setSecureRandom(new SecureRandom()); - localFLParameter.setFlID(flParameter.getClientID()); - FLLiteClient client = new FLLiteClient(); - FLClientStatus curStatus; - curStatus = client.initSession(); - if (curStatus == FLClientStatus.FAILED) { - LOGGER.severe(Common.addTag("init session failed")); - flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode()); - return curStatus; + flJobResultCallback = flParameter.getIflJobResultCallback(); + if (!Common.checkFLName(flParameter.getFlName()) && ANDROID.equals(flParameter.getDeployEnv())) { + Common.setSecureRandom(Common.getFastSecureRandom()); + } else { + Common.setSecureRandom(new SecureRandom()); } - + initFlIDForPkiVerify(); + localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false); + FLLiteClient flLiteClient = new FLLiteClient(); + LOGGER.info(Common.addTag("recovery StopJobFlag to false in the start of fl job")); + localFLParameter.setStopJobFlag(false); do { - LOGGER.info(Common.addTag("flName: " + flParameter.getFlName())); - int trainDataSize = client.setInput(flParameter.getTrainDataset()); - if (trainDataSize <= 0) { - LOGGER.severe(Common.addTag("unsolved error code in : the return trainDataSize<=0")); - curStatus = FLClientStatus.FAILED; - flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), - client.getRetCode()); + if (checkStopJobFlag()) { break; } - client.setTrainDataSize(trainDataSize); + LOGGER.info(Common.addTag("flName: " + flParameter.getFlName())); + int trainDataSize = flLiteClient.setInput(); + if (trainDataSize <= 0) { + curStatus = FLClientStatus.FAILED; + failed("unsolved error code in : the return trainDataSize<=0, setInput", + flLiteClient); + break; + } + flLiteClient.setTrainDataSize(trainDataSize); // startFLJob - curStatus = startFLJob(client); + curStatus = startFLJob(flLiteClient); if (curStatus == FLClientStatus.RESTART) { - restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + restart("[startFLJob]", flLiteClient.getNextRequestTime(), flLiteClient); continue; } else if (curStatus == FLClientStatus.FAILED) { - failed("[startFLJob]", client.getIteration(), client.getRetCode(), curStatus); + failed("[startFLJob]", flLiteClient); break; } - LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration())); - - // get the feature map before train - getOldFeatureMap(client); + LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration())); // create mask - curStatus = client.getFeatureMask(); + curStatus = flLiteClient.getFeatureMask(); if (curStatus == FLClientStatus.RESTART) { - restart("[Encrypt] creatMask", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + restart("[Encrypt] creatMask", flLiteClient.getNextRequestTime(), flLiteClient); continue; } else if (curStatus == FLClientStatus.FAILED) { - failed("[Encrypt] createMask", client.getIteration(), client.getRetCode(), curStatus); + failed("[Encrypt] createMask", flLiteClient); break; } // train - curStatus = client.localTrain(); + curStatus = flLiteClient.localTrain(); if (curStatus == FLClientStatus.FAILED) { - failed("[train] train", client.getIteration(), client.getRetCode(), curStatus); + failed("[train] train", flLiteClient); break; } LOGGER.info(Common.addTag("[train] train succeed")); // updateModel - curStatus = updateModel(client); + curStatus = updateModel(flLiteClient); if (curStatus == FLClientStatus.RESTART) { - restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + restart("[updateModel]", flLiteClient.getNextRequestTime(), flLiteClient); continue; } else if (curStatus == FLClientStatus.FAILED) { - failed("[updateModel] updateModel", client.getIteration(), client.getRetCode(), curStatus); + failed("[updateModel] updateModel", flLiteClient); break; } // unmasking - curStatus = client.unMasking(); + curStatus = flLiteClient.unMasking(); if (curStatus == FLClientStatus.RESTART) { - restart("[Encrypt] unmasking", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + restart("[Encrypt] unmasking", flLiteClient.getNextRequestTime(), flLiteClient); continue; } else if (curStatus == FLClientStatus.FAILED) { - failed("[Encrypt] unmasking", client.getIteration(), client.getRetCode(), curStatus); + failed("[Encrypt] unmasking", flLiteClient); break; } // getModel - curStatus = getModel(client); + curStatus = getModel(flLiteClient); if (curStatus == FLClientStatus.RESTART) { - restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); + restart("[getModel]", flLiteClient.getNextRequestTime(), flLiteClient); continue; } else if (curStatus == FLClientStatus.FAILED) { - failed("[getModel] getModel", client.getIteration(), client.getRetCode(), curStatus); + failed("[getModel] getModel", flLiteClient); break; } // get the feature map after averaging and update dp_norm_clip - updateDpNormClip(client); + flLiteClient.updateDpNormClip(); // evaluate model after getting model from server - if (flParameter.getTestDataset().equals("null")) { - LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" + - " model from server")); + if (!checkEvalPath()) { + LOGGER.info(Common.addTag("[evaluate] the data map set by user do not contain evaluation dataset, " + + "don't evaluate the model after getting model from server")); } else { - curStatus = client.evaluateModel(); + curStatus = flLiteClient.evaluateModel(); if (curStatus == FLClientStatus.FAILED) { - failed("[evaluate] evaluate", client.getIteration(), client.getRetCode(), curStatus); + failed("[evaluate] evaluate", flLiteClient); break; } LOGGER.info(Common.addTag("[evaluate] evaluate succeed")); } LOGGER.info(Common.addTag("========================================================the total response of " - + client.getIteration() + ": " + curStatus + + + flLiteClient.getIteration() + ": " + curStatus + "======================================================================")); - flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), - client.getRetCode()); - } while (client.getIteration() < client.getIterations()); - client.freeSession(); + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(), + flLiteClient.getRetCode()); + Common.freeSession(); + } while (flLiteClient.getIteration() < flLiteClient.getIterations()); LOGGER.info(Common.addTag("flJobRun finish")); - flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode()); + flJobResultCallback.onFlJobFinished(flParameter.getFlName(), flLiteClient.getIterations(), + flLiteClient.getRetCode()); return curStatus; } - private FLClientStatus startFLJob(FLLiteClient client) { - FLClientStatus curStatus = client.startFLJob(); + private FLClientStatus startFLJob(FLLiteClient flLiteClient) { + FLClientStatus curStatus = flLiteClient.startFLJob(); while (curStatus == FLClientStatus.WAIT) { waitSomeTime(); - curStatus = client.startFLJob(); + curStatus = flLiteClient.startFLJob(); } return curStatus; } - private FLClientStatus updateModel(FLLiteClient client) { - FLClientStatus curStatus = client.updateModel(); + private FLClientStatus updateModel(FLLiteClient flLiteClient) { + FLClientStatus curStatus = flLiteClient.updateModel(); while (curStatus == FLClientStatus.WAIT) { waitSomeTime(); - curStatus = client.updateModel(); + curStatus = flLiteClient.updateModel(); } return curStatus; } - private FLClientStatus getModel(FLLiteClient client) { - FLClientStatus curStatus = client.getModel(); + private FLClientStatus getModel(FLLiteClient flLiteClient) { + FLClientStatus curStatus = flLiteClient.getModel(); while (curStatus == FLClientStatus.WAIT) { waitSomeTime(); - curStatus = client.getModel(); + curStatus = flLiteClient.getModel(); } return curStatus; } - private void updateDpNormClip(FLLiteClient client) { - EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); - if (encryptLevel == EncryptLevel.DP_ENCRYPT) { - int currentIter = client.getIteration(); - Map fedFeatureMap = getFeatureMap(); - float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap); - if (fedWeightUpdateNorm == -1) { - LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " + - "-1, please check!")); - throw new IllegalArgumentException(); + private boolean checkEvalPath() { + boolean tag = true; + if (Common.checkFLName(flParameter.getFlName())) { + if ("null".equals(flParameter.getTestDataset())) { + tag = false; } - LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm)); - float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm; - if (currentIter == 1) { - client.setDpNormClipAdapt(newNormCLip); - LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); - } else { - if (newNormCLip < client.getDpNormClipAdapt()) { - client.setDpNormClipAdapt(newNormCLip); - LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); - } - } - LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt())); + return tag; + } + if (!flParameter.getDataMap().containsKey(RunType.EVALMODE)) { + LOGGER.info(Common.addTag("[evaluate] the data map set by user do not contain evaluation dataset, " + + "don't evaluate the model after getting model from server")); + tag = false; + return tag; + } + return tag; + } + + private boolean checkStopJobFlag() { + if (localFLParameter.isStopJobFlag()) { + LOGGER.info(Common.addTag("the stopJObFlag is set to true, the job will be stop")); + curStatus = FLClientStatus.FAILED; + return true; + } else { + return false; } } - private void getOldFeatureMap(FLLiteClient client) { - EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); - if (encryptLevel == EncryptLevel.DP_ENCRYPT) { - Map featureMap = getFeatureMap(); - oldFeatureMap = client.getOldMapCopy(featureMap); - } - } - - private float calWeightUpdateNorm(Map originalData, Map newData) { - float updateL2Norm = 0f; - for (String key : originalData.keySet()) { - float[] data = originalData.get(key); - float[] dataAfterUpdate = newData.get(key); - for (int j = 0; j < data.length; j++) { - if (j >= dataAfterUpdate.length) { - LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " + - "please check"); - return -1; - } - float updateData = data[j] - dataAfterUpdate[j]; - updateL2Norm += updateData * updateData; - } - } - updateL2Norm = (float) Math.sqrt(updateL2Norm); - return updateL2Norm; - } - - private Map getFeatureMap() { - Map featureMap = new HashMap<>(); - if (flParameter.getFlName().equals(ALBERT)) { - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - } else if (flParameter.getFlName().equals(LENET)) { - TrainLenet trainLenet = TrainLenet.getInstance(); - featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); - } - return featureMap; - } /** * Starts an inference task on the device. @@ -265,6 +265,292 @@ public class SyncFLJob { * @return the status code corresponding to the response message. */ public int[] modelInference() { + if (Common.checkFLName(flParameter.getFlName())) { + return deprecatedModelInference(); + } + + Client client = ClientManager.getClient(flParameter.getFlName()); + localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false); + localFLParameter.setStopJobFlag(false); + int[] labels = new int[0]; + Map dataSize = client.initDataSets(flParameter.getDataMap()); + if (dataSize.isEmpty()) { + LOGGER.severe("[model inference] initDataSets failed, please check"); + return new int[0]; + } + Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig()); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[model inference] unsolved error code in : the return " + + " status is: " + tag)); + return new int[0]; + } + client.setBatchSize(flParameter.getBatchSize()); + LOGGER.info(Common.addTag("===========model inference=============")); + labels = client.inferModel().stream().mapToInt(Integer::valueOf).toArray(); + if (labels == null || labels.length == 0) { + LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please " + + "check"); + } + LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); + client.free(); + LOGGER.info(Common.addTag("[model inference] inference finish")); + return labels; + } + + /** + * Obtains the latest model on the cloud. + * + * @return the status code corresponding to the response message. + */ + public FLClientStatus getModel() { + if (!Common.checkFLName(flParameter.getFlName()) && ANDROID.equals(flParameter.getDeployEnv())) { + Common.setSecureRandom(Common.getFastSecureRandom()); + } else { + Common.setSecureRandom(new SecureRandom()); + } + if (Common.checkFLName(flParameter.getFlName())) { + return deprecatedGetModel(); + } + localFLParameter.setServerMod(flParameter.getServerMod().toString()); + localFLParameter.setMsConfig(0, 1, 0, false); + FLClientStatus status; + FLLiteClient flLiteClient = new FLLiteClient(); + status = flLiteClient.getModel(); + return status; + } + + /** + * use to stop FL job. + */ + public void stopFLJob() { + LOGGER.info(Common.addTag("will stop the flJob")); + localFLParameter.setStopJobFlag(true); + Common.notifyObject(); + } + + private void waitSomeTime() { + if (flParameter.getSleepTime() != 0) { + Common.sleep(flParameter.getSleepTime()); + } else { + Common.sleep(SLEEP_TIME); + } + } + + private void waitNextReqTime(String nextReqTime) { + long waitTime = Common.getWaitTime(nextReqTime); + Common.sleep(waitTime); + } + + private void restart(String tag, String nextReqTime, FLLiteClient flLiteClient) { + LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again")); + waitNextReqTime(nextReqTime); + Common.freeSession(); + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(), + flLiteClient.getRetCode()); + } + + private void failed(String tag, FLLiteClient flLiteClient) { + LOGGER.info(Common.addTag(tag + " failed")); + LOGGER.info(Common.addTag("=========================================the total response of " + + flLiteClient.getIteration() + ": " + curStatus + "=========================================")); + Common.freeSession(); + flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), flLiteClient.getIteration(), + flLiteClient.getRetCode()); + } + + private static Map> createDatasetMap(String trainDataPath, String evalDataPath, + String inferDataPath, String pathRegex) { + Map> dataMap = new HashMap<>(); + if ((trainDataPath == null) || ("null".equals(trainDataPath)) || (trainDataPath.isEmpty())) { + LOGGER.info(Common.addTag("the trainDataPath is null or empty, please check if you are in the case of " + + "noly inference")); + } else { + dataMap.put(RunType.TRAINMODE, Arrays.asList(trainDataPath.split(pathRegex))); + LOGGER.info(Common.addTag("the trainDataPath: " + Arrays.toString(trainDataPath.split(pathRegex)))); + } + + if ((evalDataPath == null) || ("null".equals(evalDataPath)) || (evalDataPath.isEmpty())) { + LOGGER.info(Common.addTag("the evalDataPath is null or empty, please check if you are in the case of only" + + " trainning without evaluation")); + } else { + dataMap.put(RunType.EVALMODE, Arrays.asList(evalDataPath.split(pathRegex))); + LOGGER.info(Common.addTag("the evalDataPath: " + Arrays.toString(evalDataPath.split(pathRegex)))); + } + + if ((inferDataPath == null) || ("null".equals(inferDataPath)) || (inferDataPath.isEmpty())) { + LOGGER.info(Common.addTag("the inferDataPath is null or empty, please check if you are in the case of " + + "trainning without inference")); + } else { + dataMap.put(RunType.INFERMODE, Arrays.asList(inferDataPath.split(pathRegex))); + LOGGER.info(Common.addTag("the inferDataPath: " + Arrays.toString(inferDataPath.split(pathRegex)))); + } + return dataMap; + } + + private static void createWeightNameList(String trainWeightName, String inferWeightName, String nameRegex, + FLParameter flParameter) { + if ((trainWeightName == null) || ("null".equals(trainWeightName)) || (trainWeightName.isEmpty())) { + LOGGER.info(Common.addTag("the trainWeightName is null or empty, only need in " + ServerMod.HYBRID_TRAINING)); + } else { + flParameter.setHybridWeightName(Arrays.asList(trainWeightName.split(nameRegex)), RunType.TRAINMODE); + LOGGER.info(Common.addTag("the trainWeightName: " + Arrays.toString(trainWeightName.split(nameRegex)))); + } + + if ((inferWeightName == null) || ("null".equals(inferWeightName)) || (inferWeightName.isEmpty())) { + LOGGER.info(Common.addTag("the inferWeightName is null or empty, only need in " + ServerMod.HYBRID_TRAINING)); + } else { + flParameter.setHybridWeightName(Arrays.asList(inferWeightName.split(nameRegex)), RunType.INFERMODE); + LOGGER.info(Common.addTag("the trainWeightName: " + Arrays.toString(inferWeightName.split(nameRegex)))); + } + } + + private static void task(String[] args) { + String trainDataPath = args[0]; + String evalDataPath = args[1]; + String inferDataPath = args[2]; + String pathRegex = args[3]; + + String flName = args[4]; + String trainModelPath = args[5]; + String inferModelPath = args[6]; + String sslProtocol = args[7]; + String deployEnv = args[8]; + String domainName = args[9]; + String certPath = args[10]; + boolean useElb = Boolean.parseBoolean(args[11]); + int serverNum = Integer.parseInt(args[12]); + String task = args[13]; + int threadNum = Integer.parseInt(args[14]); + String cpuBindMode = args[15]; + String trainWeightName = args[16]; + String inferWeightName = args[17]; + String nameRegex = args[18]; + String serverMod = args[19]; + int batchSize = Integer.parseInt(args[20]); + + FLParameter flParameter = FLParameter.getInstance(); + + // create dataset of map + Map> dataMap = createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex); + + // create weight name of list + createWeightNameList(trainWeightName, inferWeightName, nameRegex, flParameter); + + flParameter.setFlName(flName); + SyncFLJob syncFLJob = new SyncFLJob(); + Common.setIsHttps(domainName.split("//")[0].split(":")[0]); + switch (task) { + case "train": + LOGGER.info(Common.addTag("start syncFLJob.flJobRun()")); + if (Common.isHttps()) { + flParameter.setCertPath(certPath); + } + flParameter.setDataMap(dataMap); + flParameter.setTrainModelPath(trainModelPath); + flParameter.setInferModelPath(inferModelPath); + flParameter.setSslProtocol(sslProtocol); + flParameter.setDeployEnv(deployEnv); + flParameter.setDomainName(domainName); + flParameter.setUseElb(useElb); + flParameter.setServerNum(serverNum); + flParameter.setThreadNum(threadNum); + flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode)); + flParameter.setBatchSize(batchSize); + syncFLJob.flJobRun(); + break; + case "inference": + LOGGER.info(Common.addTag("start syncFLJob.modelInference()")); + flParameter.setDataMap(dataMap); + flParameter.setInferModelPath(inferModelPath); + flParameter.setThreadNum(threadNum); + flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode)); + flParameter.setBatchSize(batchSize); + syncFLJob.modelInference(); + break; + case "getModel": + LOGGER.info(Common.addTag("start syncFLJob.getModel()")); + if (Common.isHttps()) { + flParameter.setCertPath(certPath); + } + flParameter.setTrainModelPath(trainModelPath); + flParameter.setInferModelPath(inferModelPath); + flParameter.setSslProtocol(sslProtocol); + flParameter.setDeployEnv(deployEnv); + flParameter.setDomainName(domainName); + flParameter.setUseElb(useElb); + flParameter.setServerNum(serverNum); + flParameter.setServerMod(ServerMod.valueOf(serverMod)); + syncFLJob.getModel(); + break; + default: + LOGGER.info(Common.addTag("do not do any thing!")); + } + } + + + private static void deprecatedTask(String[] args) { + String trainDataset = args[0]; + String vocabFile = args[1]; + String idsFile = args[2]; + String testDataset = args[3]; + String flName = args[4]; + String trainModelPath = args[5]; + String inferModelPath = args[6]; + boolean useSSL = Boolean.parseBoolean(args[7]); + String domainName = args[8]; + boolean useElb = Boolean.parseBoolean(args[9]); + int serverNum = Integer.parseInt(args[10]); + String certPath = args[11]; + String task = args[12]; + + FLParameter flParameter = FLParameter.getInstance(); + flParameter.setFlName(flName); + SyncFLJob syncFLJob = new SyncFLJob(); + Common.setIsHttps(domainName.split("//")[0].split(":")[0]); + switch (task) { + case "train": + if (Common.isHttps()) { + flParameter.setCertPath(certPath); + } + flParameter.setTrainDataset(trainDataset); + flParameter.setTrainModelPath(trainModelPath); + flParameter.setTestDataset(testDataset); + flParameter.setInferModelPath(inferModelPath); + flParameter.setDomainName(domainName); + flParameter.setUseElb(useElb); + flParameter.setServerNum(serverNum); + if (ALBERT.equals(flName)) { + flParameter.setVocabFile(vocabFile); + flParameter.setIdsFile(idsFile); + } + syncFLJob.flJobRun(); + break; + case "inference": + flParameter.setTestDataset(testDataset); + flParameter.setInferModelPath(inferModelPath); + if (ALBERT.equals(flName)) { + flParameter.setVocabFile(vocabFile); + flParameter.setIdsFile(idsFile); + } + syncFLJob.modelInference(); + break; + case "getModel": + if (Common.isHttps()) { + flParameter.setCertPath(certPath); + } + flParameter.setTrainModelPath(trainModelPath); + flParameter.setInferModelPath(inferModelPath); + flParameter.setDomainName(domainName); + flParameter.setUseElb(useElb); + flParameter.setServerNum(serverNum); + syncFLJob.getModel(); + break; + default: + LOGGER.info(Common.addTag("do not do any thing!")); + } + } + + private int[] deprecatedModelInference() { int[] labels = new int[0]; if (flParameter.getFlName().equals(ALBERT)) { AlInferBert alInferBert = AlInferBert.getInstance(); @@ -290,166 +576,27 @@ public class SyncFLJob { LOGGER.info(Common.addTag("[model inference] inference finish")); } return labels; + } - /** - * Obtains the latest model on the cloud. - * - * @return the status code corresponding to the response message. - */ - public FLClientStatus getModel() { - Common.setSecureRandom(Common.getFastSecureRandom()); - int tag = 0; + private FLClientStatus deprecatedGetModel() { + localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString()); FLClientStatus status; - try { - if (flParameter.getFlName().equals(ALBERT)) { - localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString()); - LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + - flParameter.getTrainModelPath() + " Create Train Session=============")); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); - if (tag == -1) { - LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the " + - "return is -1")); - return FLClientStatus.FAILED; - } - LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + - flParameter.getInferModelPath() + " Create inference Session=============")); - AlInferBert alInferBert = AlInferBert.getInstance(); - tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); - } else if (flParameter.getFlName().equals(LENET)) { - localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString()); - LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + - flParameter.getTrainModelPath() + " Create Train Session=============")); - TrainLenet trainLenet = TrainLenet.getInstance(); - tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); - } - if (tag == -1) { - LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return " + - "is -1")); - return FLClientStatus.FAILED; - } - FLCommunication flCommunication = FLCommunication.getInstance(); - String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), - flParameter.getDomainName()); - GetModel getModelBuf = GetModel.getInstance(); - byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0); - byte[] message = flCommunication.syncRequest(url + "/getModel", buffer); - if (!Common.isSeverReady(message)) { - LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " + - "again")); - status = FLClientStatus.WAIT; - return status; - } - LOGGER.info(Common.addTag("[getModel] get model request success")); - ByteBuffer debugBuffer = ByteBuffer.wrap(message); - ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer); - status = getModelBuf.doResponse(responseDataBuf); - LOGGER.info(Common.addTag("[getModel] success!")); - } catch (Exception ex) { - LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage())); - status = FLClientStatus.FAILED; - } - if (flParameter.getFlName().equals(ALBERT)) { - LOGGER.info(Common.addTag("===========free train session=============")); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - SessionUtil.free(alTrainBert.getTrainSession()); - LOGGER.info(Common.addTag("===========free inference session=============")); - AlInferBert alInferBert = AlInferBert.getInstance(); - SessionUtil.free(alInferBert.getTrainSession()); - } else if (flParameter.getFlName().equals(LENET)) { - LOGGER.info(Common.addTag("===========free session=============")); - TrainLenet trainLenet = TrainLenet.getInstance(); - SessionUtil.free(trainLenet.getTrainSession()); - } + FLLiteClient flLiteClient = new FLLiteClient(); + status = flLiteClient.getModel(); return status; } - private void waitSomeTime() { - if (flParameter.getSleepTime() != 0) { - Common.sleep(flParameter.getSleepTime()); - } else { - Common.sleep(SLEEP_TIME); - } - } - - private void waitNextReqTime(String nextReqTime) { - long waitTime = Common.getWaitTime(nextReqTime); - Common.sleep(waitTime); - } - - private void restart(String tag, String nextReqTime, int iteration, int retcode) { - LOGGER.info(Common.addTag(tag + " out of time: need wait and request startFLJob again")); - waitNextReqTime(nextReqTime); - flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode); - } - - private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) { - LOGGER.info(Common.addTag(tag + " failed")); - LOGGER.info(Common.addTag("=========================================the total response of " + - iteration + ": " + curStatus + "=========================================")); - flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode); - } - public static void main(String[] args) { - String trainDataset = args[0]; - String vocabFile = args[1]; - String idsFile = args[2]; - String testDataset = args[3]; - String flName = args[4]; - String trainModelPath = args[5]; - String inferModelPath = args[6]; - boolean useSSL = Boolean.parseBoolean(args[7]); - String domainName = args[8]; - boolean useElb = Boolean.parseBoolean(args[9]); - int serverNum = Integer.parseInt(args[10]); - String certPath = args[11]; - String task = args[12]; - - FLParameter flParameter = FLParameter.getInstance(); - - SyncFLJob syncFLJob = new SyncFLJob(); - if (task.equals("train")) { - if (useSSL) { - flParameter.setCertPath(certPath); - } - flParameter.setTrainDataset(trainDataset); - flParameter.setFlName(flName); - flParameter.setTrainModelPath(trainModelPath); - flParameter.setTestDataset(testDataset); - flParameter.setInferModelPath(inferModelPath); - flParameter.setUseSSL(useSSL); - flParameter.setDomainName(domainName); - flParameter.setUseElb(useElb); - flParameter.setServerNum(serverNum); - if (ALBERT.equals(flName)) { - flParameter.setVocabFile(vocabFile); - flParameter.setIdsFile(idsFile); - } - syncFLJob.flJobRun(); - } else if (task.equals("inference")) { - flParameter.setFlName(flName); - flParameter.setTestDataset(testDataset); - flParameter.setInferModelPath(inferModelPath); - if (ALBERT.equals(flName)) { - flParameter.setVocabFile(vocabFile); - flParameter.setIdsFile(idsFile); - } - syncFLJob.modelInference(); - } else if (task.equals("getModel")) { - if (useSSL) { - flParameter.setCertPath(certPath); - } - flParameter.setFlName(flName); - flParameter.setTrainModelPath(trainModelPath); - flParameter.setInferModelPath(inferModelPath); - flParameter.setUseSSL(useSSL); - flParameter.setDomainName(domainName); - flParameter.setUseElb(useElb); - flParameter.setServerNum(serverNum); - syncFLJob.getModel(); - } else { - LOGGER.info(Common.addTag("do not do any thing!")); + if (args[4] == null || args[4].isEmpty()) { + LOGGER.severe(Common.addTag("the parameter of is null, please check")); + throw new IllegalArgumentException(); } + if (Common.checkFLName(args[4])) { + deprecatedTask(args); + } else { + task(args); + } + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java index f7cd712e602..230b223561b 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java @@ -16,14 +16,16 @@ package com.mindspore.flclient; -import static com.mindspore.flclient.LocalFLParameter.ALBERT; -import static com.mindspore.flclient.LocalFLParameter.LENET; - import com.google.flatbuffers.FlatBufferBuilder; import com.mindspore.flclient.model.AlTrainBert; +import com.mindspore.flclient.model.Client; +import com.mindspore.flclient.model.ClientManager; +import com.mindspore.flclient.model.CommonUtils; import com.mindspore.flclient.model.SessionUtil; +import com.mindspore.flclient.model.Status; import com.mindspore.flclient.model.TrainLenet; +import com.mindspore.lite.MSTensor; import mindspore.schema.FeatureMap; import mindspore.schema.RequestUpdateModel; @@ -33,9 +35,13 @@ import mindspore.schema.ResponseUpdateModel; import java.util.ArrayList; import java.util.Date; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + /** * Define the serialization method, handle the response message returned from server for updateModel request. * @@ -52,6 +58,7 @@ public class UpdateModel { private FLParameter flParameter = FLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private FLClientStatus status; + private int retCode = ResponseCode.RequestError; private UpdateModel() { } @@ -78,6 +85,10 @@ public class UpdateModel { return status; } + public int getRetCode() { + return retCode; + } + /** * Get a flatBuffer builder of RequestUpdateModel. * @@ -88,7 +99,16 @@ public class UpdateModel { */ public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) { RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel()); - return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()) + boolean isPkiVerify = flParameter.isPkiVerify(); + if (isPkiVerify) { + Date date = new Date(); + long timestamp = date.getTime(); + String dateTime = String.valueOf(timestamp); + byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration); + return builder.flName(flParameter.getFlName()).time(dateTime).id(localFLParameter.getFlID()) + .featuresMap(secureProtocol, trainDataSize).iteration(iteration).signData(signature).build(); + } + return builder.flName(flParameter.getFlName()).time("null").id(localFLParameter.getFlID()) .featuresMap(secureProtocol, trainDataSize).iteration(iteration).build(); } @@ -99,8 +119,9 @@ public class UpdateModel { * @return the status code corresponding to the response message. */ public FLClientStatus doResponse(ResponseUpdateModel response) { - LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================")); - LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode())); + retCode = response.retcode(); + LOGGER.info(Common.addTag("[updateModel] ==========the response message of updateModel is================")); + LOGGER.info(Common.addTag("[updateModel] ==========retCode: " + retCode)); LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason())); LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime())); switch (response.retcode()) { @@ -120,6 +141,65 @@ public class UpdateModel { } } + private Map getFeatureMap() { + Client client = ClientManager.getClient(flParameter.getFlName()); + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + throw new IllegalArgumentException(); + } + List features = client.getFeatures(); + Map trainedMap = CommonUtils.convertTensorToFeatures(features); + LOGGER.info(Common.addTag("[updateModel] ===========free session=============")); + Common.freeSession(); + if (trainedMap.isEmpty()) { + LOGGER.severe(Common.addTag("[updateModel] the return trainedMap is empty in ")); + retCode = ResponseCode.RequestError; + status = FLClientStatus.FAILED; + throw new IllegalArgumentException(); + } + return trainedMap; + } + + private Map deprecatedGetFeatureMap() { + status = Common.initSession(flParameter.getTrainModelPath()); + if (status == FLClientStatus.FAILED) { + retCode = ResponseCode.RequestError; + throw new IllegalArgumentException(); + } + Map map = new HashMap(); + if (flParameter.getFlName().equals(ALBERT)) { + LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + + flParameter.getFlName())); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); + if (map.isEmpty()) { + LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); + status = FLClientStatus.FAILED; + throw new IllegalArgumentException(); + } + } else if (flParameter.getFlName().equals(LENET)) { + LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + + flParameter.getFlName())); + TrainLenet trainLenet = TrainLenet.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); + if (map.isEmpty()) { + LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); + status = FLClientStatus.FAILED; + throw new IllegalArgumentException(); + } + } else { + LOGGER.severe(Common.addTag("[updateModel] the flName is not valid")); + status = FLClientStatus.FAILED; + throw new IllegalArgumentException(); + } + Common.freeSession(); + return map; + } + class RequestUpdateModelBuilder { private RequestUpdateModel requestUM; private FlatBufferBuilder builder; @@ -127,6 +207,7 @@ public class UpdateModel { private int nameOffset = 0; private int idOffset = 0; private int timestampOffset = 0; + private int signDataOffset = 0; private int iteration = 0; private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT; @@ -153,12 +234,22 @@ public class UpdateModel { /** * Serialize the element timestamp in RequestUpdateModel. * + * @param setTime current timestamp when the request is sent. * @return the RequestUpdateModelBuilder object. */ - private RequestUpdateModelBuilder time() { - Date date = new Date(); - long time = date.getTime(); - this.timestampOffset = builder.createString(String.valueOf(time)); + private RequestUpdateModelBuilder time(String setTime) { + if (setTime == null || setTime.isEmpty()) { + LOGGER.severe(Common.addTag("[updateModel] the parameter of is null or empty, please " + + "check!")); + throw new IllegalArgumentException(); + } + if (setTime.equals("null")) { + Date date = new Date(); + long time = date.getTime(); + this.timestampOffset = builder.createString(String.valueOf(time)); + } else { + this.timestampOffset = builder.createString(setTime); + } return this; } @@ -189,10 +280,16 @@ public class UpdateModel { } private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) { - ArrayList encryptFeatureName = secureProtocol.getEncryptFeatureName(); + ArrayList updateFeatureName = secureProtocol.getUpdateFeatureName(); + Map trainedMap = new HashMap(); + if (Common.checkFLName(flParameter.getFlName())) { + trainedMap = deprecatedGetFeatureMap(); + } else { + trainedMap = getFeatureMap(); + } switch (encryptLevel) { case PW_ENCRYPT: - int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize); + int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize, trainedMap); if (fmOffsetsPW == null || fmOffsetsPW.length == 0) { LOGGER.severe("[Encrypt] the return fmOffsetsPW from is " + "null, please check"); @@ -202,10 +299,12 @@ public class UpdateModel { LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!")); return this; case DP_ENCRYPT: - int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize); + int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize, trainedMap); if (fmOffsetsDP == null || fmOffsetsDP.length == 0) { LOGGER.severe("[Encrypt] the return fmOffsetsDP from is " + "null, please check"); + retCode = ResponseCode.RequestError; + status = FLClientStatus.FAILED; throw new IllegalArgumentException(); } this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP); @@ -213,36 +312,11 @@ public class UpdateModel { return this; case NOT_ENCRYPT: default: - Map map = new HashMap(); - if (flParameter.getFlName().equals(ALBERT)) { - LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + - flParameter.getFlName())); - AlTrainBert alTrainBert = AlTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - if (map.isEmpty()) { - LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); - status = FLClientStatus.FAILED; - } - } else if (flParameter.getFlName().equals(LENET)) { - LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + - flParameter.getFlName())); - TrainLenet trainLenet = TrainLenet.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); - if (map.isEmpty()) { - LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); - status = FLClientStatus.FAILED; - } - } else { - LOGGER.severe(Common.addTag("[updateModel] the flName is not valid")); - throw new IllegalArgumentException(); - } - int featureSize = encryptFeatureName.size(); + int featureSize = updateFeatureName.size(); int[] fmOffsets = new int[featureSize]; for (int i = 0; i < featureSize; i++) { - String key = encryptFeatureName.get(i); - float[] data = map.get(key); + String key = updateFeatureName.get(i); + float[] data = trainedMap.get(key); LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " + "size: " + data.length)); for (int j = 0; j < data.length; j++) { @@ -258,6 +332,22 @@ public class UpdateModel { } } + /** + * Serialize the element signature in RequestUpdateModel. + * + * @param signData the signature Data. + * @return the RequestUpdateModelBuilder object. + */ + private RequestUpdateModelBuilder signData(byte[] signData) { + if (signData == null || signData.length == 0) { + LOGGER.severe(Common.addTag("[updateModel] the parameter of is null or empty, please " + + "check!")); + throw new IllegalArgumentException(); + } + this.signDataOffset = RequestUpdateModel.createSignatureVector(builder, signData); + return this; + } + /** * Create a flatBuffer builder of RequestUpdateModel. * @@ -270,6 +360,7 @@ public class UpdateModel { RequestUpdateModel.addTimestamp(builder, this.timestampOffset); RequestUpdateModel.addIteration(builder, this.iteration); RequestUpdateModel.addFeatureMap(builder, this.fmOffset); + RequestUpdateModel.addSignature(builder, this.signDataOffset); int root = RequestUpdateModel.endRequestUpdateModel(builder); builder.finish(root); return builder.sizedByteArray(); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/AESEncrypt.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/AESEncrypt.java index d67368c14c4..8ba864af940 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/AESEncrypt.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/AESEncrypt.java @@ -21,7 +21,6 @@ import static com.mindspore.flclient.LocalFLParameter.KEY_LEN; import com.mindspore.flclient.Common; -import java.io.UnsupportedEncodingException; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CertVerify.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CertVerify.java new file mode 100644 index 00000000000..067a018bc73 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CertVerify.java @@ -0,0 +1,411 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient.cipher; + +import com.mindspore.flclient.Common; +import com.mindspore.flclient.FLParameter; + +import org.bouncycastle.asn1.ASN1OctetString; +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier; +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier; +import org.bouncycastle.util.encoders.Hex; + +import java.io.ByteArrayInputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.InvalidKeyException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.PublicKey; +import java.security.SignatureException; +import java.security.UnrecoverableEntryException; +import java.security.cert.CRLException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509CRL; +import java.security.cert.X509CRLEntry; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.Date; +import java.util.Set; +import java.util.logging.Logger; + +/** + * Certificate verification class + * + * @since 2021-8-27 + */ +public class CertVerify { + private static final Logger LOGGER = Logger.getLogger(CertVerify.class.toString()); + + /** + * Verify the legitimacy of certificate chain + * + * @param clientID clientID of this client + * @param x509Certificates certificate chain + * @return verification result + */ + public static boolean verifyCertificateChain(String clientID, X509Certificate[] x509Certificates) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!")); + return false; + } + if (x509Certificates == null || x509Certificates.length < 2) { + LOGGER.severe(Common.addTag("[CertVerify] the parameter x509Certificates is null or the length is not " + + "valid: < 2, please check!")); + return false; + } + if (verifyChain(clientID, x509Certificates) && verifyCommonName(clientID, x509Certificates) + && verifyCrl(clientID, x509Certificates) && verifyValidDate(x509Certificates) && + verifyKeyIdentifier(clientID, x509Certificates)) { + LOGGER.info(Common.addTag("[CertVerify] verifyCertificateChain success!")); + return true; + } + LOGGER.severe(Common.addTag("[CertVerify] verifyCertificateChain failed!")); + return false; + } + + private static boolean verifyCommonName(String clientID, X509Certificate[] x509Certificate) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!")); + return false; + } + if (x509Certificate == null || x509Certificate.length < 2) { + LOGGER.severe(Common.addTag("[CertVerify] x509Certificate chains is null or the length is not valid: < 2," + + " please check!")); + return false; + } + X509Certificate[] certificateChains = getX509CertificateChain(clientID); + if (certificateChains == null || certificateChains.length < 4) { + LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 4, " + + "please check!")); + return false; + } + X509Certificate localEquipCACert = certificateChains[2]; + // get subjectDN of local root equipment CA certificate + String localEquipCAName = localEquipCACert.getSubjectDN().getName(); + // get issueDN of client's equipment certificate + X509Certificate remoteEquipCert = x509Certificate[1]; + String equipIssueName = remoteEquipCert.getIssuerDN().getName(); + return localEquipCAName.equals(equipIssueName); + } + + // check whether the former certificate owner is the publisher of next one. + private static boolean verifyChain(String clientID, X509Certificate[] x509Certificates) { + if (x509Certificates == null || x509Certificates.length < 2) { + LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 2, " + + "please check!")); + return false; + } + + // check remote equipment certificate + try { + X509Certificate[] certificateChains = getX509CertificateChain(clientID); + if (certificateChains == null || certificateChains.length < 3) { + LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 3, " + + "please check!")); + return false; + } + X509Certificate localEquipCA = certificateChains[2]; + PublicKey publicKey = localEquipCA.getPublicKey(); + x509Certificates[1].verify(publicKey); + } catch (NoSuchProviderException | CertificateException | NoSuchAlgorithmException | + InvalidKeyException | SignatureException e) { + LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage())); + return false; + } + + // check remote service certificate + X509Certificate remoteEquipCert = x509Certificates[1]; + X509Certificate remoteServiceCert = x509Certificates[0]; + + try { + remoteEquipCert.checkValidity(); + remoteServiceCert.checkValidity(); + } catch (java.security.cert.CertificateExpiredException | + java.security.cert.CertificateNotYetValidException e) { + e.printStackTrace(); + return false; + } + + try { + PublicKey publicKey = remoteEquipCert.getPublicKey(); + remoteServiceCert.verify(publicKey); + } catch (CertificateException | NoSuchAlgorithmException | InvalidKeyException | + NoSuchProviderException | SignatureException e) { + LOGGER.severe(Common.addTag("verifyChain failed!")); + LOGGER.severe(Common.addTag("[verifyChain] catch Exception: " + e.getMessage())); + return false; + } + LOGGER.severe(Common.addTag("verifyChain success!")); + return true; + } + + /** + * get certificate chain according to clientID + * + * @param clientID clientID of this client + * @return certificate chain + */ + public static X509Certificate[] getX509CertificateChain(String clientID) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!")); + return null; + } + X509Certificate[] x509Certificates = null; + try { + Certificate[] certificates = null; + KeyStore keyStore = KeyStore.getInstance(CipherConsts.KEYSTORE_TYPE); + keyStore.load(null); + KeyStore.Entry entry = keyStore.getEntry(clientID, null); + if (entry == null || !(entry instanceof KeyStore.PrivateKeyEntry)) { + return null; + } + certificates = ((KeyStore.PrivateKeyEntry) entry).getCertificateChain(); + if (certificates == null) { + return null; + } + x509Certificates = (X509Certificate[]) certificates; + } catch (IOException | NoSuchAlgorithmException | UnrecoverableEntryException | KeyStoreException | + CertificateException e) { + LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage())); + } + return x509Certificates; + } + + /** + * transform pem format to X509 format + * + * @param pemCerts pem format certificate + * @return X509 format certificates + */ + public static X509Certificate[] transformPemArrayToX509Array(String[] pemCerts) { + if (pemCerts == null || pemCerts.length == 0) { + LOGGER.severe(Common.addTag("[CertVerify] pemCerts is null or empty, please check!")); + throw new IllegalArgumentException(); + } + int nSize = pemCerts.length; + X509Certificate[] x509Certificates = new X509Certificate[nSize]; + for (int i = 0; i < nSize; ++i) { + x509Certificates[i] = transformPemToX509(pemCerts[i]); + } + return x509Certificates; + } + + private static X509Certificate transformPemToX509(String pemCert) { + X509Certificate x509Certificate = null; + CertificateFactory cf; + try { + if (pemCert != null && !pemCert.trim().isEmpty()) { + byte[] certificateData = Base64.getDecoder().decode(pemCert); + cf = CertificateFactory.getInstance("X509"); + x509Certificate = (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(certificateData)); + } + } catch (CertificateException e) { + LOGGER.severe(Common.addTag("[CertVerify] catch Exception: " + e.getMessage())); + return null; + } + return x509Certificate; + } + + private static boolean verifyCrl(String clientID, X509Certificate[] x509Certificates) { + if (x509Certificates == null || x509Certificates.length < 2) { + LOGGER.severe(Common.addTag("[verifyCrl] the number of certificate in x509Certificates is less than 2, " + + "please check!")); + throw new IllegalArgumentException(); + } + FLParameter flParameter = FLParameter.getInstance(); + X509Certificate equipCert = x509Certificates[1]; + if (equipCert == null) { + LOGGER.severe(Common.addTag("[verifyCrl] equipCert is null, please check it!")); + return false; + } + String equipCertSerialNum = equipCert.getSerialNumber().toString(); + if (verifySingleCrl(clientID, equipCertSerialNum, flParameter.getEquipCrlPath())) { + LOGGER.info(Common.addTag("[verifyCrl] verify crl certificate success!")); + return true; + } + LOGGER.info(Common.addTag("[verifyCrl] verify crl certificate failed!")); + return false; + } + + private static boolean verifySingleCrl(String clientID, String caSerialNumber, String crlPath) { + if (caSerialNumber == null || caSerialNumber.isEmpty()) { + LOGGER.severe(Common.addTag("[CertVerify] caSerialNumber is null or empty, please check!")); + throw new IllegalArgumentException(); + } + // crlPath does not exist + if (crlPath.equals("null")) { + LOGGER.severe(Common.addTag("[CertVerify] crlPath is null, please set crlPath with setEquipCrlPath " + + "method!")); + return false; + } + boolean notInFlag = true; + try { + X509CRL crl = (X509CRL) readCrl(crlPath); + if (crl != null) { + // check CRL cert with local equipment CA publicKey + X509Certificate[] certificateChains = getX509CertificateChain(clientID); + if (certificateChains == null || certificateChains.length < 3) { + LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not" + + " valid: < 3, please check!")); + return false; + } + X509Certificate localEquipCA = certificateChains[2]; + PublicKey publicKey = localEquipCA.getPublicKey(); + crl.verify(publicKey); + + // check whether remote equipmentCert in CRL + Set set = crl.getRevokedCertificates(); + if (set == null) { + LOGGER.info(Common.addTag("[verifySingleCrl] verifyCrl Revoked Cert list is null")); + return true; + } + for (Object obj : set) { + X509CRLEntry crlEntity = (X509CRLEntry) obj; + if (crlEntity.getSerialNumber().toString().equals(caSerialNumber)) { + LOGGER.info(Common.addTag("[verifySingleCrl] Find same SerialNumber during the crl!")); + notInFlag = false; + break; + } + } + } + } catch (java.security.cert.CRLException | java.security.NoSuchAlgorithmException | + java.security.InvalidKeyException | java.security.NoSuchProviderException | + java.security.SignatureException e) { + LOGGER.severe(Common.addTag("[verifySingleCrl] judgeCAInCRL error: " + e.getMessage())); + notInFlag = false; + } + return notInFlag; + } + + private static Object readCrl(String assetName) { + if (assetName == null || assetName.isEmpty()) { + LOGGER.severe(Common.addTag("[readCrl] the parameter of is null or empty, please check!")); + return null; + } + InputStream inputStream = null; + try { + inputStream = new FileInputStream(assetName); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[readCrl] catch Exception of read inputStream in readCert: " + + e.getMessage())); + return null; + } + Object crlCert = null; + try { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + crlCert = cf.generateCRL(inputStream); + } catch (CertificateException | CRLException e) { + LOGGER.severe(Common.addTag("[readCrl] catch Exception of creating CertificateFactory in readCert: " + + e.getMessage())); + } finally { + try { + inputStream.close(); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[readCrl] catch Exception of close inputStream: " + + e.getMessage())); + } + } + + return crlCert; + } + + private static boolean verifyValidDate(X509Certificate[] x509Certificates) { + if (x509Certificates == null) { + LOGGER.severe(Common.addTag("[CertVerify] x509Certificates is null, please check!")); + throw new IllegalArgumentException(); + } + Date date = new Date(); + try { + int nSize = x509Certificates.length; + for (int i = 0; i < nSize; ++i) { + x509Certificates[i].checkValidity(date); + } + } catch (java.security.cert.CertificateExpiredException | + java.security.cert.CertificateNotYetValidException e) { + LOGGER.severe(Common.addTag("[verifyValidDate] catch Exception: " + e.getMessage())); + return false; + } + return true; + } + + private static boolean verifyKeyIdentifier(String clientID, X509Certificate[] x509Certificates) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[CertVerify] the parameter clientID is null or empty, please check!")); + return false; + } + if (x509Certificates == null || x509Certificates.length < 2) { + LOGGER.severe(Common.addTag("[CertVerify] x509Certificate chains is null or the length is not valid: < 2," + + " please check!")); + return false; + } + + X509Certificate[] certificateChains = getX509CertificateChain(clientID); + if (certificateChains == null || certificateChains.length < 3) { + LOGGER.severe(Common.addTag("[CertVerify] certificateChains is null or the length is not valid: < 3, " + + "please check!")); + return false; + } + + X509Certificate localEquipCACert = certificateChains[2]; + String subjectIdentifier = "null"; + + // get subjectKeyIdentifier of local root equipment CA certificate + try { + String subjectIdentifierOid = "2.5.29.14"; + byte[] subjectExtendData = localEquipCACert.getExtensionValue(subjectIdentifierOid); + ASN1OctetString asn1OctetString = ASN1OctetString.getInstance(subjectExtendData); + byte[] tmpData = asn1OctetString.getOctets(); + SubjectKeyIdentifier subjectKeyIdentifier = SubjectKeyIdentifier.getInstance(tmpData); + byte[] octKeyIdentifier = subjectKeyIdentifier.getKeyIdentifier(); + subjectIdentifier = new String(Hex.encode(octKeyIdentifier)); + } catch (ExceptionInInitializerError e) { + e.printStackTrace(); + } + + // get authorityKeyIdentifier of client's equipment certificate + X509Certificate remoteEquipCert = x509Certificates[1]; + String authorityIdentifier = "null"; + try { + if (remoteEquipCert == null) { + LOGGER.severe(Common.addTag("[CertVerify] remoteEquipCert is null, please check it!")); + return false; + } + String authorityIdentifierOid = "2.5.29.35"; + byte[] authExtendData = remoteEquipCert.getExtensionValue(authorityIdentifierOid); + ASN1OctetString asn1OctetString = ASN1OctetString.getInstance(authExtendData); + byte[] tmpData = asn1OctetString.getOctets(); + AuthorityKeyIdentifier authorityKeyIdentifier = AuthorityKeyIdentifier.getInstance(tmpData); + byte[] octKeyIdentifier = authorityKeyIdentifier.getKeyIdentifier(); + authorityIdentifier = new String(Hex.encode(octKeyIdentifier)); + } catch (ExceptionInInitializerError e) { + e.printStackTrace(); + } + + if (authorityIdentifier.equals("null") || subjectIdentifier.equals("null")) { + LOGGER.severe(Common.addTag("[CertVerify] authorityKeyIdentifier or subjectKeyIdentifier is null, check " + + "failed!")); + return false; + } else { + return authorityIdentifier.equals(subjectIdentifier); + } + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CipherConsts.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CipherConsts.java new file mode 100644 index 00000000000..159cfdcbe8c --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/CipherConsts.java @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mindspore.flclient.cipher; + +/** + * Consts used for certification signature + * + * @since 2021-8-27 + */ +public class CipherConsts { + /** + * provider name + */ + public static final String PROVIDER_NAME = "HwUniversalKeyStoreProvider"; + + /** + * keyStore type + */ + public static final String KEYSTORE_TYPE = "HwKeyStore"; + + /** + * sign algorithm + */ + public static final String SIGN_ALGORITHM = "SHA256withRSA/PSS"; +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java index f4b9b9b3c8f..494206ad448 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ClientListReq.java @@ -21,6 +21,7 @@ import static com.mindspore.flclient.FLParameter.SLEEP_TIME; import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.CipherClient; import com.mindspore.flclient.Common; import com.mindspore.flclient.FLClientStatus; import com.mindspore.flclient.FLCommunication; @@ -91,12 +92,28 @@ public class ClientListReq { long timestamp = date.getTime(); String dateTime = String.valueOf(timestamp); int time = builder.createString(dateTime); - - GetClientList.startGetClientList(builder); - GetClientList.addFlId(builder, id); - GetClientList.addIteration(builder, iteration); - GetClientList.addTimestamp(builder, time); - int clientListRoot = GetClientList.endGetClientList(builder); + int clientListRoot; + byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration); + if (signature == null) { + LOGGER.severe(Common.addTag("[getClientList] get signature is null!")); + return FLClientStatus.FAILED; + } + if (signature.length > 0) { + int signed = GetClientList.createSignatureVector(builder, signature); + GetClientList.startGetClientList(builder); + GetClientList.addFlId(builder, id); + GetClientList.addIteration(builder, iteration); + GetClientList.addTimestamp(builder, time); + GetClientList.addSignature(builder, signed); + clientListRoot = GetClientList.endGetClientList(builder); + } else { + GetClientList.startGetClientList(builder); + GetClientList.addFlId(builder, id); + GetClientList.addIteration(builder, iteration); + GetClientList.addTimestamp(builder, time); + GetClientList.addSignature(builder, 0); + clientListRoot = GetClientList.endGetClientList(builder); + } builder.finish(clientListRoot); byte[] msg = builder.sizedByteArray(); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName()); @@ -107,6 +124,7 @@ public class ClientListReq { "request again")); Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return FLClientStatus.RESTART; } ByteBuffer buffer = ByteBuffer.wrap(responseData); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java index 2cdfb63d0c6..74879b89735 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/ReconstructSecretReq.java @@ -20,6 +20,7 @@ import static com.mindspore.flclient.FLParameter.SLEEP_TIME; import com.google.flatbuffers.FlatBufferBuilder; +import com.mindspore.flclient.CipherClient; import com.mindspore.flclient.Common; import com.mindspore.flclient.FLClientStatus; import com.mindspore.flclient.FLCommunication; @@ -68,7 +69,8 @@ public class ReconstructSecretReq { */ public FLClientStatus sendReconstructSecret(List decryptShareSecretsList, List u3ClientList, int iteration) { - String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName()); + String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), + flParameter.getDomainName()); FlatBufferBuilder builder = new FlatBufferBuilder(); int desFlId = builder.createString(localFLParameter.getFlID()); Date date = new Date(); @@ -106,12 +108,28 @@ public class ReconstructSecretReq { } int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList); - SendReconstructSecret.startSendReconstructSecret(builder); - SendReconstructSecret.addFlId(builder, desFlId); - SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets); - SendReconstructSecret.addIteration(builder, iteration); - SendReconstructSecret.addTimestamp(builder, time); - int reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder); + + int reconstructSecretRoot; + byte[] signature = CipherClient.signTimeAndIter(dateTime, iteration); + if (signature.length > 0) { + int signed = SendReconstructSecret.createSignatureVector(builder, signature); + SendReconstructSecret.startSendReconstructSecret(builder); + SendReconstructSecret.addFlId(builder, desFlId); + SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets); + SendReconstructSecret.addIteration(builder, iteration); + SendReconstructSecret.addTimestamp(builder, time); + SendReconstructSecret.addSignature(builder, signed); + reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder); + } else { + SendReconstructSecret.startSendReconstructSecret(builder); + SendReconstructSecret.addFlId(builder, desFlId); + SendReconstructSecret.addReconstructSecretShares(builder, reconstructShareSecrets); + SendReconstructSecret.addIteration(builder, iteration); + SendReconstructSecret.addTimestamp(builder, time); + SendReconstructSecret.addSignature(builder, 0); + reconstructSecretRoot = SendReconstructSecret.endSendReconstructSecret(builder); + } + builder.finish(reconstructSecretRoot); byte[] msg = builder.sizedByteArray(); try { @@ -121,6 +139,7 @@ public class ReconstructSecretReq { "time and request again")); Common.sleep(SLEEP_TIME); nextRequestTime = ""; + retCode = ResponseCode.OutOfTime; return FLClientStatus.RESTART; } ByteBuffer buffer = ByteBuffer.wrap(responseData); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/SignAndVerify.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/SignAndVerify.java new file mode 100644 index 00000000000..65ada55e006 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/SignAndVerify.java @@ -0,0 +1,152 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mindspore.flclient.cipher; + +import com.mindspore.flclient.Common; + +import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.Signature; +import java.security.SignatureException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.logging.Logger; + +/** + * class used for sign data and verify data + * + * @since 2021-8-27 + */ +public class SignAndVerify { + private static final Logger LOGGER = Logger.getLogger(SignAndVerify.class.toString()); + + /** + * sign data + * + * @param clientID ID of this client + * @param data data need to be signed + * @return signed data + */ + public static byte[] signData(String clientID, byte[] data) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[SignAndVerify] the parameter clientID is null or empty, please check!")); + return null; + } + if (data == null || data.length == 0) { + LOGGER.severe(Common.addTag("[SignAndVerify] the parameter data is null or empty, please check!")); + return null; + } + byte[] signData = null; + try { + KeyStore ks = KeyStore.getInstance(CipherConsts.KEYSTORE_TYPE); + ks.load(null); + Key privateKey = ks.getKey(clientID, null); + if (privateKey == null) { + LOGGER.info("private key is null"); + return null; + } + Signature signature = Signature.getInstance(CipherConsts.SIGN_ALGORITHM, CipherConsts.PROVIDER_NAME); + signature.initSign((PrivateKey) privateKey); + signature.update(data); + signData = signature.sign(); + } catch (KeyStoreException | CertificateException | NoSuchAlgorithmException | IOException + | UnrecoverableKeyException | NoSuchProviderException | InvalidKeyException + | SignatureException e) { + LOGGER.severe(Common.addTag("[SignAndVerify] catch Exception: " + e.getMessage())); + } + return signData; + } + + /** + * verify signature with certifications + * + * @param clientID ID of this client + * @param x509Certificates certificates + * @param data original data + * @param signed signed data + * @return verify result + */ + public static boolean verifySignatureByCert(String clientID, X509Certificate[] x509Certificates, byte[] data, + byte[] signed) { + if (clientID == null || clientID.isEmpty()) { + LOGGER.severe(Common.addTag("[SignAndVerify] the parameter clientID is null or empty, please check!")); + return false; + } + if (x509Certificates == null || x509Certificates.length < 1) { + LOGGER.severe(Common.addTag("[SignAndVerify] the parameter x509Certificates is null or the length is not " + + "valid: < 1, please check!")); + return false; + } + if (data == null || data.length == 0) { + LOGGER.severe(Common.addTag("[SignAndVerify] the parameter data is null or empty, please check!")); + return false; + } + if (signed == null || signed.length == 0) { + LOGGER.severe(Common.addTag("[SignAndVerify] the parameter signed is null or empty, please check!")); + return false; + } + if (!CertVerify.verifyCertificateChain(clientID, x509Certificates)) { + LOGGER.info(Common.addTag("Verify chain failed!")); + return false; + } + LOGGER.info(Common.addTag("Verify chain success!")); + + boolean isValid; + try { + if (x509Certificates[0].getPublicKey() == null) { + LOGGER.severe(Common.addTag("[SignAndVerify] get public key failed!")); + return false; + } + PublicKey publicKey = x509Certificates[0].getPublicKey(); // get public key + Signature signature = Signature.getInstance(CipherConsts.SIGN_ALGORITHM); + signature.initVerify(publicKey); // set public key + signature.update(data); // set data + isValid = signature.verify(signed); // verify the consistence between signature and data + } catch (NoSuchAlgorithmException | SignatureException | InvalidKeyException e) { + LOGGER.severe(Common.addTag("[SignAndVerify] catch Exception: " + e.getMessage())); + return false; + } + return isValid; + } + + /** + * get hash result of bytes + * + * @param bytes inputs + * @return hash value of bytes + */ + public static byte[] getSHA256(byte[] bytes) { + MessageDigest messageDigest; + byte[] hash = new byte[0]; + try { + messageDigest = MessageDigest.getInstance("SHA-256"); + hash = messageDigest.digest(bytes); + } catch (NoSuchAlgorithmException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch NoSuchAlgorithmException: " + e.getMessage())); + } + return hash; + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ClientPublicKey.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ClientPublicKey.java index c2b5ab6924a..a265099b390 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ClientPublicKey.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ClientPublicKey.java @@ -40,7 +40,8 @@ public class ClientPublicKey { */ public String getFlID() { if (flID == null || flID.isEmpty()) { - LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of is null, please set it before use")); + LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of is null, please set it before " + + "using")); throw new IllegalArgumentException(); } return flID; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/EncryptShare.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/EncryptShare.java index 236edc551fe..c3a1a1832e0 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/EncryptShare.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/EncryptShare.java @@ -38,7 +38,7 @@ public class EncryptShare { public String getFlID() { if (flID == null || flID.isEmpty()) { LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of is null, please set it before " + - "use")); + "using")); throw new IllegalArgumentException(); } return flID; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/NewArray.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/NewArray.java index 28a06c78bc7..7a7678b818f 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/NewArray.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/NewArray.java @@ -20,7 +20,6 @@ package com.mindspore.flclient.cipher.struct; * class used define new array type * * @param an array - * * @since 2021-8-27 */ public class NewArray { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ShareSecret.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ShareSecret.java index b22d9df5f31..d5385bc18ea 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ShareSecret.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/cipher/struct/ShareSecret.java @@ -38,7 +38,7 @@ public class ShareSecret { */ public String getFlID() { if (flID == null || flID.isEmpty()) { - LOGGER.severe(Common.addTag("[ShareSecret] the parameter of is null, please set it before use")); + LOGGER.severe(Common.addTag("[ShareSecret] the parameter of is null, please set it before using")); throw new IllegalArgumentException(); } return flID; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiBean.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiBean.java new file mode 100644 index 00000000000..01445eae402 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiBean.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved. + */ + +package com.mindspore.flclient.pki; + +import java.security.cert.Certificate; + +/** + * PkiBean entity + * + * @since 2021-08-25 + */ +public class PkiBean { + private byte[] signData; + + private Certificate[] certificates; + + public PkiBean(byte[] signData, Certificate[] certificates) { + this.signData = signData; + this.certificates = certificates; + } + + public byte[] getSignData() { + return signData; + } + + public void setSignData(byte[] signData) { + this.signData = signData; + } + + public Certificate[] getCertificates() { + return certificates; + } + + public void setCertificates(Certificate[] certificates) { + this.certificates = certificates; + } +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiConsts.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiConsts.java new file mode 100644 index 00000000000..948fccfcbd1 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiConsts.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved. + */ + +package com.mindspore.flclient.pki; + +/** + * Pki consts + * + * @since 2021-08-25 + */ +public class PkiConsts { + /** + * PROVIDER_NAME + */ + public static final String PROVIDER_NAME = "HwUniversalKeyStoreProvider"; + + /** + * KEYSTORE_TYPE + */ + public static final String KEYSTORE_TYPE = "HwKeyStore"; + + /** + * ALGORITHM + */ + public static final String ALGORITHM = "SHA256withRSA/PSS"; +} diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiUtil.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiUtil.java new file mode 100644 index 00000000000..4d9600fd5a2 --- /dev/null +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/pki/PkiUtil.java @@ -0,0 +1,247 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved. + */ + +package com.mindspore.flclient.pki; + +import com.mindspore.flclient.Common; +import com.mindspore.flclient.LocalFLParameter; + +import org.bouncycastle.util.io.pem.PemObject; +import org.bouncycastle.util.io.pem.PemWriter; + +import java.io.IOException; +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.PrivateKey; +import java.security.Signature; +import java.security.SignatureException; +import java.security.UnrecoverableEntryException; +import java.security.UnrecoverableKeyException; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Locale; +import java.util.logging.Logger; + +/** + * Pki Util + * + * @since 2021-08-25 + */ +public class PkiUtil { + private static final Logger LOGGER = Logger.getLogger(PkiUtil.class.toString()); + + /** + * generate PkiBean + * + * @param clientID String + * @param time long + * @return PkiBean + */ + public static PkiBean genPkiBean(String clientID, long time) { + String sourceData = LocalFLParameter.getInstance().getFlID() + " " + time; + byte[] signDataBytes = PkiUtil.signData(clientID, + sourceData.getBytes(StandardCharsets.UTF_8)); + + Certificate[] certificates = PkiUtil.getCertificateChain(clientID); + + return new PkiBean(signDataBytes, certificates); + } + + /** + * get str of SHA256 + * + * @param str String + * @return hash + */ + public static byte[] getSHA256Str(String str) { + MessageDigest messageDigest; + byte[] hash = new byte[0]; + try { + messageDigest = MessageDigest.getInstance("SHA-256"); + hash = messageDigest.digest(str.getBytes(StandardCharsets.UTF_8)); + } catch (NoSuchAlgorithmException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch NoSuchAlgorithmException: " + e.getMessage())); + } + return hash; + } + + /** + * get Pem format + * + * @param certificate Certificate + * @return String result + * @throws IOException e + */ + public static String getPemFormat(Certificate certificate) throws IOException { + StringWriter writer = new StringWriter(); + PemWriter pemWriter = new PemWriter(writer); + try { + pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded())); + } catch (IOException | CertificateEncodingException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch IOException or CertificateEncodingException in getPermFormat: " + + e.getMessage())); + } finally { + pemWriter.flush(); + pemWriter.close(); + } + + return writer.toString(); + } + + private static byte[] signData(String clientID, byte[] data) { + byte[] signData = null; + try { + KeyStore ks = KeyStore.getInstance(PkiConsts.KEYSTORE_TYPE); + ks.load(null); + Key privateKey = ks.getKey(clientID, null); + if (privateKey == null) { + return new byte[0]; + } + Signature signature = Signature.getInstance(PkiConsts.ALGORITHM, PkiConsts.PROVIDER_NAME); + if (privateKey instanceof PrivateKey) { + signature.initSign((PrivateKey) privateKey); + } + signature.update(data); + signData = signature.sign(); + } catch (KeyStoreException | CertificateException | NoSuchAlgorithmException | IOException + | UnrecoverableKeyException | NoSuchProviderException | InvalidKeyException + | SignatureException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage())); + } + return signData; + } + + /** + * get certificate chain + * + * @param clientID String + * @return Certificate[] + */ + public static Certificate[] getCertificateChain(String clientID) { + Certificate[] certificates = null; + try { + KeyStore keyStore = KeyStore.getInstance(PkiConsts.KEYSTORE_TYPE); + keyStore.load(null); + + KeyStore.Entry entry = keyStore.getEntry(clientID, null); + if (entry == null) { + return new Certificate[0]; + } + + if (!(entry instanceof KeyStore.PrivateKeyEntry)) { + return new Certificate[0]; + } + + certificates = ((KeyStore.PrivateKeyEntry) entry).getCertificateChain(); + } catch (IOException | CertificateException | NoSuchAlgorithmException + | UnrecoverableEntryException | KeyStoreException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage())); + } + + return certificates; + } + + /** + * to hex format + * + * @param data byte[] + * @return String + */ + public static String toHexFormat(byte[] data) { + if (data == null || data.length == 0) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (byte byteData : data) { + sb.append(String.format(Locale.ROOT, "%02x", byteData)); + } + + return sb.toString(); + } + + /** + * gen equip cert hash + * + * @param clientID String + * @return String + */ + public static String genEquipCertHash(String clientID) { + String equipCert; + byte[] equipCertBytesHash = null; + try { + Certificate[] certificates = getCertificateChain(clientID); + if (certificates == null || certificates.length < 2) { + return ""; + } + equipCert = readPemFormat(certificates[1]); + equipCertBytesHash = getSHA256Str(equipCert); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage())); + } + + return toHexFormat(equipCertBytesHash); + } + + /** + * generate hash from cer + * + * @param certificateGradeOne X509Certificate + * @return String + */ + public static String genHashFromCer(X509Certificate certificateGradeOne) { + String equipCert = null; + byte[] equipCertBytesHash = null; + try { + equipCert = readPemFormat(certificateGradeOne); + equipCertBytesHash = getSHA256Str(equipCert); + } catch (IOException e) { + LOGGER.severe(Common.addTag("[PkiUtil] catch Exception: " + e.getMessage())); + } + if (equipCertBytesHash == null) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (byte byteData : equipCertBytesHash) { + sb.append(String.format(Locale.ROOT, "%02x", byteData)); + } + return sb.toString(); + } + + /** + * read pem format + * + * @param certificate Certificate + * @return String result + * @throws IOException e + */ + public static String readPemFormat(Certificate certificate) throws IOException { + StringWriter writer = new StringWriter(); + PemWriter pemWriter = new PemWriter(writer); + if (certificate == null) { + LOGGER.severe(Common.addTag("[PkiUtil] the input parameter certificate is null, please check")); + throw new IllegalArgumentException(); + } + try { + pemWriter.writeObject(new PemObject("CERTIFICATE", certificate.getEncoded())); + } catch (IOException | CertificateEncodingException e) { + LOGGER.severe( + Common.addTag("[PkiUtil] catch IOException or CertificateEncodingException in getPermFormat: " + + e.getMessage())); + } finally { + pemWriter.flush(); + pemWriter.close(); + } + + return writer.toString(); + } +} \ No newline at end of file