change dir and add Encrypt code
This commit is contained in:
parent
f98497ca09
commit
b7b58a2ab9
|
@ -0,0 +1,578 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.cipher.AESEncrypt;
|
||||
import com.mindspore.flclient.cipher.BaseUtil;
|
||||
import com.mindspore.flclient.cipher.ClientListReq;
|
||||
import com.mindspore.flclient.cipher.KEYAgreement;
|
||||
import com.mindspore.flclient.cipher.Random;
|
||||
import com.mindspore.flclient.cipher.ReconstructSecretReq;
|
||||
import com.mindspore.flclient.cipher.ShareSecrets;
|
||||
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
|
||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||
import com.mindspore.flclient.cipher.struct.ShareSecret;
|
||||
import mindspore.schema.ClientShare;
|
||||
import mindspore.schema.GetExchangeKeys;
|
||||
import mindspore.schema.GetShareSecrets;
|
||||
import mindspore.schema.RequestExchangeKeys;
|
||||
import mindspore.schema.RequestShareSecrets;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ResponseExchangeKeys;
|
||||
import mindspore.schema.ResponseShareSecrets;
|
||||
import mindspore.schema.ReturnExchangeKeys;
|
||||
import mindspore.schema.ReturnShareSecrets;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.math.BigInteger;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.spec.InvalidKeySpecException;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
|
||||
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
|
||||
|
||||
public class CipherClient {
|
||||
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private final int iteration;
|
||||
private int featureSize;
|
||||
private int t;
|
||||
private List<byte[]> cKey = new ArrayList<>();
|
||||
private List<byte[]> sKey = new ArrayList<>();
|
||||
private byte[] bu;
|
||||
private String nextRequestTime;
|
||||
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
|
||||
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
|
||||
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
|
||||
private List<EncryptShare> clientShareList = new ArrayList<>();
|
||||
private List<EncryptShare> returnShareList = new ArrayList<>();
|
||||
private float[] featureMask;
|
||||
private List<String> u1ClientList = new ArrayList<>();
|
||||
private List<String> u2UClientList = new ArrayList<>();
|
||||
private List<String> u3ClientList = new ArrayList<>();
|
||||
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
|
||||
private byte[] prime;
|
||||
private KEYAgreement keyAgreement = new KEYAgreement();
|
||||
private Random random = new Random();
|
||||
private ClientListReq clientListReq = new ClientListReq();
|
||||
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
|
||||
|
||||
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
this.iteration = iter;
|
||||
this.featureSize = featureSize;
|
||||
this.t = minSecretNum;
|
||||
this.prime = prime;
|
||||
this.featureMask = new float[this.featureSize];
|
||||
}
|
||||
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public void setBU(byte[] bu) {
|
||||
this.bu = bu;
|
||||
}
|
||||
|
||||
public void setClientShareList(List<EncryptShare> clientShareList) {
|
||||
this.clientShareList.clear();
|
||||
this.clientShareList = clientShareList;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public void genDHKeyPairs() {
|
||||
byte[] csk = keyAgreement.generatePrivateKey();
|
||||
byte[] cpk = keyAgreement.generatePublicKey(csk);
|
||||
byte[] ssk = keyAgreement.generatePrivateKey();
|
||||
byte[] spk = keyAgreement.generatePublicKey(ssk);
|
||||
this.cKey.add(cpk);
|
||||
this.cKey.add(csk);
|
||||
this.sKey.add(spk);
|
||||
this.sKey.add(ssk);
|
||||
}
|
||||
|
||||
public void genIndividualSecret() {
|
||||
byte[] key = new byte[SEED_SIZE];
|
||||
random.getRandomBytes(key);
|
||||
setBU(key);
|
||||
}
|
||||
|
||||
public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException {
|
||||
List<ShareSecret> shareSecretList = new ArrayList<>();
|
||||
int size = u1ClientList.size();
|
||||
ShareSecrets shamir = new ShareSecrets(t, size - 1);
|
||||
ShareSecrets.SecretShare[] shares = shamir.split(secret, prime);
|
||||
int j = 0;
|
||||
for (int i = 0; i < size; i++) {
|
||||
String vFlID = u1ClientList.get(i);
|
||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||
continue;
|
||||
} else {
|
||||
ShareSecret shareSecret = new ShareSecret();
|
||||
NewArray<byte[]> array = new NewArray<>();
|
||||
int index = shares[j].getNum();
|
||||
BigInteger intShare = shares[j].getShare();
|
||||
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
|
||||
array.setSize(share.length);
|
||||
array.setArray(share);
|
||||
shareSecret.setFlID(vFlID);
|
||||
shareSecret.setShare(array);
|
||||
shareSecret.setIndex(index);
|
||||
shareSecretList.add(shareSecret);
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
return shareSecretList;
|
||||
}
|
||||
|
||||
public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException {
|
||||
cUVKeys.clear();
|
||||
for (String key : clientPublicKeyList.keySet()) {
|
||||
ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
|
||||
String vFlID = curPublicKey.getFlID();
|
||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||
continue;
|
||||
} else {
|
||||
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
|
||||
byte[] salt = new byte[0];
|
||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||
cUVKeys.put(vFlID, secret);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void encryptShares() throws Exception {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************"));
|
||||
List<EncryptShare> encryptShareList = new ArrayList<>();
|
||||
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them
|
||||
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
|
||||
List<ShareSecret> bUV = genSecretShares(bu);
|
||||
for (int i = 0; i < bUV.size(); i++) {
|
||||
EncryptShare encryptShare = new EncryptShare();
|
||||
NewArray<byte[]> array = new NewArray<>();
|
||||
String vFlID = bUV.get(i).getFlID();
|
||||
byte[] sShare = sSkUv.get(i).getShare().getArray();
|
||||
byte[] bShare = bUV.get(i).getShare().getArray();
|
||||
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
|
||||
byte[] bIndex = BaseUtil.integer2byteArray(bUV.get(i).getIndex());
|
||||
byte[] allSecret = new byte[sShare.length + bShare.length + sIndex.length + bIndex.length + 4];
|
||||
allSecret[0] = (byte) sShare.length;
|
||||
allSecret[1] = (byte) bShare.length;
|
||||
allSecret[2] = (byte) sIndex.length;
|
||||
allSecret[3] = (byte) bIndex.length;
|
||||
System.arraycopy(sIndex, 0, allSecret, 4, sIndex.length);
|
||||
System.arraycopy(bIndex, 0, allSecret, 4 + sIndex.length, bIndex.length);
|
||||
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
|
||||
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
|
||||
// encrypt:
|
||||
byte[] iVecIn = new byte[IVEC_LEN];
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC");
|
||||
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
|
||||
array.setSize(encryptData.length);
|
||||
array.setArray(encryptData);
|
||||
encryptShare.setFlID(vFlID);
|
||||
encryptShare.setShare(array);
|
||||
encryptShareList.add(encryptShare);
|
||||
}
|
||||
setClientShareList(encryptShareList);
|
||||
}
|
||||
|
||||
public float[] doubleMaskingWeight() throws Exception {
|
||||
int size = u2UClientList.size();
|
||||
List<Float> noiseBu = new ArrayList<>();
|
||||
random.randomAESCTR(noiseBu, featureSize, bu);
|
||||
float[] mask = new float[featureSize];
|
||||
for (int i = 0; i < size; i++) {
|
||||
String vFlID = u2UClientList.get(i);
|
||||
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
|
||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||
continue;
|
||||
} else {
|
||||
byte[] salt = new byte[0];
|
||||
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
|
||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||
sUVKeys.put(vFlID, secret);
|
||||
List<Float> noiseSuv = new ArrayList<>();
|
||||
random.randomAESCTR(noiseSuv, featureSize, secret);
|
||||
int sign;
|
||||
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
|
||||
sign = 1;
|
||||
} else {
|
||||
sign = -1;
|
||||
}
|
||||
for (int j = 0; j < noiseSuv.size(); j++) {
|
||||
mask[j] = mask[j] + sign * noiseSuv.get(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < noiseBu.size(); j++) {
|
||||
mask[j] = mask[j] + noiseBu.get(j);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
|
||||
NewArray<byte[]> newArray = new NewArray<>();
|
||||
newArray.setSize(size);
|
||||
byte[] array = new byte[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
byte word = buf.get();
|
||||
array[i] = word;
|
||||
}
|
||||
newArray.setArray(array);
|
||||
return newArray;
|
||||
}
|
||||
|
||||
public FLClientStatus requestExchangeKeys() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="));
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============requestExchangeKeys url: " + url + "=============="));
|
||||
genDHKeyPairs();
|
||||
byte[] cPK = cKey.get(0);
|
||||
byte[] sPK = sKey.get(0);
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
|
||||
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time);
|
||||
fbBuilder.finish(exchangeKeysRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
|
||||
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseExchangeKeys is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus getExchangeKeys() {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============getExchangeKeys url: " + url + "=============="));
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
|
||||
fbBuilder.finish(getExchangeKeysRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
|
||||
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys success"));
|
||||
clientPublicKeyList.clear();
|
||||
u1ClientList.clear();
|
||||
int length = bufData.remotePublickeysLength();
|
||||
for (int i = 0; i < length; i++) {
|
||||
ClientPublicKey publicKey = new ClientPublicKey();
|
||||
String srcFlId = bufData.remotePublickeys(i).flId();
|
||||
publicKey.setFlID(srcFlId);
|
||||
ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer();
|
||||
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
|
||||
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
|
||||
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
|
||||
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
|
||||
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
|
||||
clientPublicKeyList.put(srcFlId, publicKey);
|
||||
u1ClientList.add(srcFlId);
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnExchangeKeys is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus requestShareSecrets() throws Exception {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============requestShareSecrets url: " + url + "=============="));
|
||||
genIndividualSecret();
|
||||
genEncryptExchangedKeys();
|
||||
encryptShares();
|
||||
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int clientShareSize = clientShareList.size();
|
||||
if (clientShareSize <= 0) {
|
||||
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
FLClientStatus status = requestShareSecrets();
|
||||
return status;
|
||||
} else {
|
||||
int[] add = new int[clientShareSize];
|
||||
for (int i = 0; i < clientShareSize; i++) {
|
||||
int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
|
||||
int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray());
|
||||
ClientShare.startClientShare(fbBuilder);
|
||||
ClientShare.addFlId(fbBuilder, flID);
|
||||
ClientShare.addShare(fbBuilder, shareSecretFbs);
|
||||
int clientShareRoot = ClientShare.endClientShare(fbBuilder);
|
||||
add[i] = clientShareRoot;
|
||||
}
|
||||
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
|
||||
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time);
|
||||
fbBuilder.finish(requestShareSecretsRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
|
||||
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseShareSecrets is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus getShareSecrets() {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============getShareSecrets url: " + url + "=============="));
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
|
||||
fbBuilder.finish(getShareSecrets);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
|
||||
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] the size of encrypted shares: " + bufData.encryptedSharesLength()));
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets success"));
|
||||
returnShareList.clear();
|
||||
u2UClientList.clear();
|
||||
int length = bufData.encryptedSharesLength();
|
||||
for (int i = 0; i < length; i++) {
|
||||
EncryptShare shareSecret = new EncryptShare();
|
||||
shareSecret.setFlID(bufData.encryptedShares(i).flId());
|
||||
ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer();
|
||||
int sizeShare = bufData.encryptedShares(i).shareLength();
|
||||
shareSecret.setShare(byteToArray(bufShare, sizeShare));
|
||||
returnShareList.add(shareSecret);
|
||||
u2UClientList.add(bufData.encryptedShares(i).flId());
|
||||
}
|
||||
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnShareSecrets is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus exchangeKeys() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================"));
|
||||
FLClientStatus curStatus;
|
||||
// RequestExchangeKeys
|
||||
curStatus = requestExchangeKeys();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = requestExchangeKeys();
|
||||
}
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
// GetExchangeKeys
|
||||
curStatus = getExchangeKeys();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = getExchangeKeys();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
public FLClientStatus shareSecrets() throws Exception {
|
||||
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================")));
|
||||
FLClientStatus curStatus;
|
||||
// RequestShareSecrets
|
||||
curStatus = requestShareSecrets();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = requestShareSecrets();
|
||||
}
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
// GetShareSecrets
|
||||
curStatus = getShareSecrets();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = getShareSecrets();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
public FLClientStatus reconstructSecrets() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================"));
|
||||
FLClientStatus curStatus;
|
||||
// GetClientList
|
||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
|
||||
}
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = clientListReq.getNextRequestTime();
|
||||
}
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
// SendReconstructSecret
|
||||
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
|
||||
}
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = reconstructSecretReq.getNextRequestTime();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class Common {
|
||||
public static final String LOG_TITLE = "<FLClient> ";
|
||||
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
||||
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "adbert"));
|
||||
|
||||
public static String generateUrl(boolean useElb, String ip, int port, int serverNum) {
|
||||
String url;
|
||||
if (useElb) {
|
||||
Random rand = new Random();
|
||||
int randomNum = rand.nextInt(100000) % serverNum + port;
|
||||
url = ip + String.valueOf(randomNum);
|
||||
} else {
|
||||
url = ip + String.valueOf(port);
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
public static void setClassifierWeightName(List<String> classifierWeightName) {
|
||||
classifierWeightName.add("albert.pooler.weight");
|
||||
classifierWeightName.add("albert.pooler.bias");
|
||||
classifierWeightName.add("classifier.weight");
|
||||
classifierWeightName.add("classifier.bias");
|
||||
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
|
||||
}
|
||||
|
||||
public static void setAlbertWeightName(List<String> albertWeightName) {
|
||||
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
|
||||
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma");
|
||||
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta");
|
||||
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
|
||||
}
|
||||
|
||||
public static boolean checkFLName(String flName) {
|
||||
return (flNameTrustList.contains(flName));
|
||||
}
|
||||
|
||||
public static void sleep(long millis) {
|
||||
try {
|
||||
Thread.sleep(millis); //1000 milliseconds is one second.
|
||||
} catch (InterruptedException ex) {
|
||||
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
public static long getWaitTime(String nextRequestTime) {
|
||||
|
||||
Date date = new Date();
|
||||
long currentTime = date.getTime();
|
||||
long waitTime = 0;
|
||||
if (!nextRequestTime.equals("")) {
|
||||
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
|
||||
}
|
||||
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
|
||||
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
|
||||
return waitTime;
|
||||
}
|
||||
|
||||
public static long startTime(String tag) {
|
||||
Date startDate = new Date();
|
||||
long startTime = startDate.getTime();
|
||||
LOGGER.info(addTag("[start time] <" + tag + "> start time: " + startTime));
|
||||
return startTime;
|
||||
}
|
||||
|
||||
public static void endTime(long start, String tag) {
|
||||
Date endDate = new Date();
|
||||
long endTime = endDate.getTime();
|
||||
LOGGER.info(addTag("[end time] <" + tag + "> end time: " + endTime));
|
||||
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
|
||||
}
|
||||
|
||||
public static String addTag(String message) {
|
||||
return LOG_TITLE + message;
|
||||
}
|
||||
|
||||
public static boolean isAutoscaling(byte[] message, String autoscalingTag) {
|
||||
return (new String(message)).contains(autoscalingTag);
|
||||
}
|
||||
|
||||
public static boolean checkPath(String path) {
|
||||
File file = new File(path);
|
||||
return file.exists();
|
||||
}
|
||||
}
|
|
@ -1,12 +1,12 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import okhttp3.Call;
|
||||
import okhttp3.Callback;
|
||||
import okhttp3.MediaType;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.RequestBody;
|
||||
import okhttp3.Response;
|
||||
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.io.IOException;
|
||||
import java.security.KeyManagementException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.TIME_OUT;
|
||||
|
||||
public class FLCommunication implements IFLCommunication {
|
||||
private static int timeOut;
|
||||
private static boolean ssl = false;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
|
||||
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
|
||||
private OkHttpClient client;
|
||||
|
||||
private static FLCommunication communication;
|
||||
|
||||
private FLCommunication() {
|
||||
if (flParameter.getTimeOut() != 0) {
|
||||
timeOut = flParameter.getTimeOut();
|
||||
} else {
|
||||
timeOut = TIME_OUT;
|
||||
}
|
||||
ssl = flParameter.isUseSSL();
|
||||
client = getUnsafeOkHttpClient();
|
||||
}
|
||||
|
||||
private static OkHttpClient getUnsafeOkHttpClient() {
|
||||
X509TrustManager trustManager = new X509TrustManager() {
|
||||
|
||||
@Override
|
||||
public X509Certificate[] getAcceptedIssuers() {
|
||||
return new X509Certificate[]{};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkServerTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkClientTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
|
||||
|
||||
}
|
||||
};
|
||||
final TrustManager[] trustAllCerts = new TrustManager[]{trustManager};
|
||||
try {
|
||||
LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut));
|
||||
OkHttpClient.Builder builder = new OkHttpClient.Builder();
|
||||
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
|
||||
if (ssl) {
|
||||
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager());
|
||||
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
|
||||
} else {
|
||||
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
|
||||
final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
|
||||
builder.sslSocketFactory(sslSocketFactory, trustManager);
|
||||
builder.hostnameVerifier(new HostnameVerifier() {
|
||||
@Override
|
||||
public boolean verify(String arg0, SSLSession arg1) {
|
||||
return true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException e) {
|
||||
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage()));
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static FLCommunication getInstance() {
|
||||
if (communication == null) {
|
||||
synchronized (FLCommunication.class) {
|
||||
if (communication == null) {
|
||||
communication = new FLCommunication();
|
||||
}
|
||||
}
|
||||
}
|
||||
return communication;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTimeOut(int timeout) throws TimeoutException {
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] syncRequest(String url, byte[] msg) throws IOException {
|
||||
Request request = new Request.Builder()
|
||||
.url(url)
|
||||
.post(RequestBody.create(MEDIA_TYPE_JSON, msg)).build();
|
||||
Response response = this.client.newCall(request).execute();
|
||||
if (!response.isSuccessful()) {
|
||||
throw new IOException("Unexpected code " + response);
|
||||
}
|
||||
return response.body().bytes();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception {
|
||||
Request request = new Request.Builder()
|
||||
.url(url)
|
||||
.header("Accept", "application/proto")
|
||||
.header("Content-Type", "application/proto; charset=utf-8")
|
||||
.post(RequestBody.create(MEDIA_TYPE_JSON, msg)).build();
|
||||
|
||||
client.newCall(request).enqueue(new Callback() {
|
||||
IAsyncCallBack asyncCallBack = callBack;
|
||||
|
||||
@Override
|
||||
public void onResponse(Call call, Response response) throws IOException {
|
||||
asyncCallBack.onResponse(response.body().bytes());
|
||||
call.cancel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Call call, IOException e) {
|
||||
asyncCallBack.onFailure(e);
|
||||
call.cancel();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
}
|
|
@ -1,12 +1,12 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
@ -17,13 +17,15 @@ package com.mindspore.flclient;
|
|||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class FLJobResultCallback implements IFLJobResultCallback{
|
||||
private static final Logger logger = Logger.getLogger(FLJobResultCallback.class.toString());
|
||||
public class FLJobResultCallback implements IFLJobResultCallback {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
|
||||
|
||||
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
|
||||
logger.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
|
||||
}
|
||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
|
||||
logger.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
|
||||
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
|
||||
}
|
||||
|
||||
}
|
||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
|
||||
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
|
||||
}
|
||||
|
||||
}
|
|
@ -1,12 +1,12 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
@ -15,10 +15,18 @@
|
|||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import javax.net.ssl.*;
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.security.*;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.NoSuchProviderException;
|
||||
import java.security.SignatureException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509Certificate;
|
||||
|
@ -32,11 +40,12 @@ public class SSLSocketFactoryTools {
|
|||
private SSLContext sslContext;
|
||||
private MyTrustManager myTrustManager;
|
||||
private static SSLSocketFactoryTools instance;
|
||||
|
||||
private SSLSocketFactoryTools() {
|
||||
initSslSocketFactory();
|
||||
}
|
||||
|
||||
private void initSslSocketFactory(){
|
||||
private void initSslSocketFactory() {
|
||||
try {
|
||||
sslContext = SSLContext.getInstance("TLS");
|
||||
x509Certificate = readCert(flParameter.getCertPath());
|
||||
|
@ -51,16 +60,14 @@ public class SSLSocketFactoryTools {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public static SSLSocketFactoryTools getInstance() {
|
||||
if (instance == null) {
|
||||
instance=new SSLSocketFactoryTools();
|
||||
instance = new SSLSocketFactoryTools();
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
|
||||
public X509Certificate readCert(String assetName) {
|
||||
public X509Certificate readCert(String assetName) {
|
||||
InputStream inputStream = null;
|
||||
try {
|
||||
inputStream = new FileInputStream(assetName);
|
||||
|
@ -110,7 +117,6 @@ public class SSLSocketFactoryTools {
|
|||
public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
||||
for (X509Certificate cert : chain) {
|
||||
|
@ -130,6 +136,7 @@ public class SSLSocketFactoryTools {
|
|||
} catch (SignatureException e) {
|
||||
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
|
||||
}
|
||||
logger.info(Common.addTag("checkServerTrusted success!"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,323 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class SecureProtocol {
|
||||
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int iteration;
|
||||
private CipherClient cipher;
|
||||
private FLClientStatus status;
|
||||
private float[] featureMask;
|
||||
private double dpEps;
|
||||
private double dpDelta;
|
||||
private double dpNormClip;
|
||||
private static double deltaError = 1e-6;
|
||||
private static Map<String, float[]> modelMap;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
|
||||
public FLClientStatus getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
public float[] getFeatureMask() {
|
||||
return featureMask;
|
||||
}
|
||||
|
||||
public SecureProtocol() {
|
||||
}
|
||||
|
||||
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||
this.iteration = iter;
|
||||
this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize);
|
||||
}
|
||||
|
||||
public FLClientStatus setDPParameter(int iter, double diffEps,
|
||||
double diffDelta, double diffNorm, Map<String, float[]> map) {
|
||||
try {
|
||||
this.iteration = iter;
|
||||
this.dpEps = diffEps;
|
||||
this.dpDelta = diffDelta;
|
||||
this.dpNormClip = diffNorm;
|
||||
this.modelMap = map;
|
||||
status = FLClientStatus.SUCCESS;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public ArrayList<String> getEncryptFeatureName() {
|
||||
return encryptFeatureName;
|
||||
}
|
||||
|
||||
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
|
||||
this.encryptFeatureName = encryptFeatureName;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return cipher.getNextRequestTime();
|
||||
}
|
||||
|
||||
public FLClientStatus pwCreateMask() {
|
||||
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
|
||||
// round 0
|
||||
status = cipher.exchangeKeys();
|
||||
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
|
||||
if (status != FLClientStatus.SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
// round 1
|
||||
try {
|
||||
status = cipher.shareSecrets();
|
||||
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
if (status != FLClientStatus.SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
// round2
|
||||
try {
|
||||
featureMask = cipher.doubleMaskingWeight();
|
||||
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals("adbert")) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals("lenet")) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int[] featuresMap = new int[featureSize];
|
||||
int maskIndex = 0;
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
float maskData = rawData * trainDataSize + featureMask[maskIndex];
|
||||
maskIndex += 1;
|
||||
data[j] = maskData;
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
featuresMap[i] = featureMap;
|
||||
}
|
||||
return featuresMap;
|
||||
}
|
||||
|
||||
public FLClientStatus pwUnmasking() {
|
||||
status = cipher.reconstructSecrets(); // round3
|
||||
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
|
||||
return status;
|
||||
}
|
||||
|
||||
private static float calculateErf(double x) {
|
||||
double result = 0;
|
||||
int segmentNum = 10000;
|
||||
double deltaX = x / segmentNum;
|
||||
result += 1;
|
||||
for (int i = 1; i < segmentNum; i++) {
|
||||
result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
|
||||
}
|
||||
result += Math.exp(-Math.pow(deltaX * segmentNum, 2));
|
||||
return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
|
||||
}
|
||||
|
||||
private static double calculatePhi(double t) {
|
||||
return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0))));
|
||||
}
|
||||
|
||||
private static double calculateBPositive(double eps, double s) {
|
||||
return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
|
||||
}
|
||||
|
||||
private static double calculateBNegative(double eps, double s) {
|
||||
return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
|
||||
}
|
||||
|
||||
private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) {
|
||||
double deltaSup = calculateBPositive(eps, sSup);
|
||||
while (deltaSup <= targetDelta) {
|
||||
sInf = sSup;
|
||||
sSup = 2 * sInf;
|
||||
deltaSup = calculateBPositive(eps, sSup);
|
||||
}
|
||||
|
||||
double sMid = sInf + (sSup - sInf) / 2.0;
|
||||
int iterMax = 1000;
|
||||
int iters = 0;
|
||||
while (true) {
|
||||
double b = calculateBPositive(eps, sMid);
|
||||
if (b <= targetDelta) {
|
||||
if (targetDelta - b <= deltaError) {
|
||||
break;
|
||||
} else {
|
||||
sInf = sMid;
|
||||
}
|
||||
} else {
|
||||
sSup = sMid;
|
||||
}
|
||||
sMid = sInf + (sSup - sInf) / 2.0;
|
||||
iters += 1;
|
||||
if (iters > iterMax) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return sMid;
|
||||
}
|
||||
|
||||
private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) {
|
||||
double deltaSup = calculateBNegative(eps, sSup);
|
||||
while (deltaSup > targetDelta) {
|
||||
sInf = sSup;
|
||||
sSup = 2 * sInf;
|
||||
deltaSup = calculateBNegative(eps, sSup);
|
||||
}
|
||||
|
||||
double sMid = sInf + (sSup - sInf) / 2.0;
|
||||
int iterMax = 1000;
|
||||
int iters = 0;
|
||||
while (true) {
|
||||
double b = calculateBNegative(eps, sMid);
|
||||
if (b <= targetDelta) {
|
||||
if (targetDelta - b <= deltaError) {
|
||||
break;
|
||||
} else {
|
||||
sSup = sMid;
|
||||
}
|
||||
} else {
|
||||
sInf = sMid;
|
||||
}
|
||||
sMid = sInf + (sSup - sInf) / 2.0;
|
||||
iters += 1;
|
||||
if (iters > iterMax) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return sMid;
|
||||
}
|
||||
|
||||
private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
|
||||
double deltaZero = calculateBPositive(eps, 0);
|
||||
double alpha = 1;
|
||||
if (targetDelta > deltaZero) {
|
||||
double s = calculateSPositive(eps, targetDelta, 0, 1);
|
||||
alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0);
|
||||
} else if (targetDelta < deltaZero) {
|
||||
double s = calculateSNegative(eps, targetDelta, 0, 1);
|
||||
alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0);
|
||||
}
|
||||
return alpha * clipNorm / Math.sqrt(2.0 * eps);
|
||||
}
|
||||
|
||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals("adbert")) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals("lenet")) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
Map<String, float[]> mapBeforeTrain = modelMap;
|
||||
int featureSize = encryptFeatureName.size();
|
||||
// calculate sigma
|
||||
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
|
||||
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
|
||||
|
||||
// prepare gaussian noise
|
||||
SecureRandom random = new SecureRandom();
|
||||
int randomInt = random.nextInt();
|
||||
Random r = new Random(randomInt);
|
||||
|
||||
// calculate l2-norm of all layers' update array
|
||||
double updateL2Norm = 0;
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||
float updateData = rawData - rawDataBeforeTrain;
|
||||
updateL2Norm += updateData * updateData;
|
||||
}
|
||||
}
|
||||
updateL2Norm = Math.sqrt(updateL2Norm);
|
||||
double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
|
||||
|
||||
// clip and add noise
|
||||
int[] featuresMap = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
float[] data2 = new float[data.length];
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||
float updateData = rawData - rawDataBeforeTrain;
|
||||
|
||||
// clip
|
||||
updateData *= clipFactor;
|
||||
|
||||
// add noise
|
||||
double gaussianNoise = r.nextGaussian() * gaussianSigma;
|
||||
updateData += gaussianNoise;
|
||||
data2[j] = rawDataBeforeTrain + updateData;
|
||||
data2[j] = data2[j] * trainDataSize;
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data2);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
featuresMap[i] = featureMap;
|
||||
}
|
||||
return featuresMap;
|
||||
}
|
||||
|
||||
}
|
|
@ -1,12 +1,12 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import javax.crypto.Cipher;
|
||||
import javax.crypto.spec.IvParameterSpec;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class AESEncrypt {
|
||||
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
|
||||
/**
|
||||
* 128, 192 or 256
|
||||
*/
|
||||
private static final int KEY_SIZE = 256;
|
||||
private static final int I_VEC_LEN = 16;
|
||||
|
||||
/**
|
||||
* encrypt/decrypt algorithm name
|
||||
*/
|
||||
private static final String ALGORITHM = "AES";
|
||||
|
||||
/**
|
||||
* algorithm/Mode/padding mode
|
||||
*/
|
||||
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
|
||||
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
|
||||
|
||||
private String CIPHER_MODE;
|
||||
|
||||
private static final int RANDOM_LEN = KEY_SIZE / 8;
|
||||
|
||||
private String iVecS = "1111111111111111";
|
||||
private byte[] iVec = iVecS.getBytes("utf-8");
|
||||
|
||||
public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException {
|
||||
if (key == null) {
|
||||
LOGGER.severe(Common.addTag("Key is null"));
|
||||
return;
|
||||
}
|
||||
if (key.length != KEY_SIZE / 8) {
|
||||
LOGGER.severe(Common.addTag("the length of key is not correct"));
|
||||
return;
|
||||
}
|
||||
|
||||
if (mode.contains("CBC")) {
|
||||
CIPHER_MODE = CIPHER_MODE_CBC;
|
||||
} else if (mode.contains("CTR")) {
|
||||
CIPHER_MODE = CIPHER_MODE_CTR;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
if (iVecIn == null || iVecIn.length != I_VEC_LEN) {
|
||||
return;
|
||||
}
|
||||
iVec = iVecIn;
|
||||
}
|
||||
|
||||
public byte[] encrypt(byte[] key, byte[] data) throws Exception {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||
byte[] encrypted = cipher.doFinal(data);
|
||||
String encryptResultStr = BaseUtil.byte2HexString(encrypted);
|
||||
return encrypted;
|
||||
}
|
||||
|
||||
public byte[] encryptCTR(byte[] key, byte[] data) throws Exception {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||
byte[] encrypted = cipher.doFinal(data);
|
||||
return encrypted;
|
||||
}
|
||||
|
||||
public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv);
|
||||
byte[] origin = cipher.doFinal(encryptData);
|
||||
return origin;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.math.BigInteger;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class BaseUtil {
|
||||
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
|
||||
|
||||
public BaseUtil() {
|
||||
}
|
||||
|
||||
public static String byte2HexString(byte[] bytes) {
|
||||
if (null == bytes) {
|
||||
return null;
|
||||
} else if (bytes.length == 0) {
|
||||
return "";
|
||||
} else {
|
||||
char[] chars = new char[bytes.length * 2];
|
||||
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
int b = bytes[i];
|
||||
chars[i * 2] = HEX_DIGITS[(b & 240) >> 4];
|
||||
chars[i * 2 + 1] = HEX_DIGITS[b & 15];
|
||||
}
|
||||
return new String(chars);
|
||||
}
|
||||
}
|
||||
|
||||
public static byte[] hexString2ByteArray(String str) {
|
||||
int length = str.length() / 2;
|
||||
byte[] bytes = new byte[length];
|
||||
byte[] source = str.getBytes(Charset.forName("UTF-8"));
|
||||
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
byte bh = Byte.decode("0x" + new String(new byte[]{source[i * 2]}, Charset.forName("UTF-8")));
|
||||
bh = (byte) (bh << 4);
|
||||
byte bl = Byte.decode("0x" + new String(new byte[]{source[i * 2 + 1]}, Charset.forName("UTF-8")));
|
||||
bytes[i] = (byte) (bh ^ bl);
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
public static BigInteger byteArray2BigInteger(byte[] bytes) {
|
||||
|
||||
BigInteger bigInteger = BigInteger.ZERO;
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
int intI = bytes[i];
|
||||
if (intI < 0) {
|
||||
intI = intI + 256;
|
||||
}
|
||||
BigInteger bi = new BigInteger(String.valueOf(intI));
|
||||
bigInteger = bigInteger.multiply(BigInteger.valueOf(256)).add(bi);
|
||||
}
|
||||
return bigInteger;
|
||||
}
|
||||
|
||||
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
|
||||
StringBuilder res = new StringBuilder();
|
||||
byte[] bytes = String.valueOf(str).getBytes("UTF-8");
|
||||
BigInteger bigInteger = BigInteger.ZERO;
|
||||
for (int i = 0; i < str.length(); ++i) {
|
||||
BigInteger bi = new BigInteger(String.valueOf(bytes[i]));
|
||||
bigInteger = bigInteger.multiply(BigInteger.valueOf(256)).add(bi);
|
||||
}
|
||||
return bigInteger;
|
||||
}
|
||||
|
||||
public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException {
|
||||
StringBuilder res = new StringBuilder();
|
||||
List<Integer> lists = new ArrayList<>();
|
||||
BigInteger bi = bigInteger;
|
||||
BigInteger DIV = BigInteger.valueOf(256);
|
||||
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
||||
lists.add(bi.mod(DIV).intValue());
|
||||
bi = bi.divide(DIV);
|
||||
}
|
||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||
res.append((char) (int) (lists.get(i)));
|
||||
}
|
||||
return res.toString();
|
||||
}
|
||||
|
||||
public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException {
|
||||
List<Integer> lists = new ArrayList<>();
|
||||
BigInteger bi = bigInteger;
|
||||
BigInteger DIV = BigInteger.valueOf(256);
|
||||
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
||||
lists.add(bi.mod(DIV).intValue());
|
||||
bi = bi.divide(DIV);
|
||||
}
|
||||
byte[] res = new byte[lists.size()];
|
||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||
res[lists.size() - i - 1] = ((byte) (int) (lists.get(i)));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public static byte[] integer2byteArray(Integer num) {
|
||||
List<Integer> lists = new ArrayList<>();
|
||||
Integer bi = num;
|
||||
Integer DIV = 256;
|
||||
while (bi > 0) {
|
||||
lists.add(bi % DIV);
|
||||
bi = bi / DIV;
|
||||
}
|
||||
byte[] res = new byte[lists.size()];
|
||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||
res[lists.size() - i - 1] = ((byte) (int) (lists.get(i)));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public static Integer byteArray2Integer(byte[] bytes) {
|
||||
|
||||
Integer num = 0;
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
int intI = bytes[i];
|
||||
if (intI < 0) {
|
||||
intI = intI + 256;
|
||||
}
|
||||
num = num * 256 + intI;
|
||||
}
|
||||
return num;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.FLCommunication;
|
||||
import com.mindspore.flclient.FLParameter;
|
||||
import com.mindspore.flclient.LocalFLParameter;
|
||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||
import mindspore.schema.GetClientList;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ReturnClientList;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
|
||||
|
||||
public class ClientListReq {
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
private String nextRequestTime;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
||||
public ClientListReq() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "=============="));
|
||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||
int id = builder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
int time = builder.createString(dateTime);
|
||||
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
|
||||
builder.finish(clientListRoot);
|
||||
byte[] msg = builder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
|
||||
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
int retcode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] the size of clients: " + bufData.clientsLength()));
|
||||
FLClientStatus status;
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList success"));
|
||||
u3ClientList.clear();
|
||||
int clientSize = bufData.clientsLength();
|
||||
for (int i = 0; i < clientSize; i++) {
|
||||
String curFlId = bufData.clients(i);
|
||||
u3ClientList.add(curFlId);
|
||||
}
|
||||
try {
|
||||
decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnClientList is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception {
|
||||
decryptSecretsList.clear();
|
||||
int size = returnShareList.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
|
||||
EncryptShare encryptShare = returnShareList.get(i);
|
||||
String vFlID = encryptShare.getFlID();
|
||||
byte[] share = encryptShare.getShare().getArray();
|
||||
byte[] iVecIn = new byte[IVEC_LEN];
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC");
|
||||
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
|
||||
int sSize = (int) decryptShare[0];
|
||||
int bSize = (int) decryptShare[1];
|
||||
int sIndexLen = (int) decryptShare[2];
|
||||
int bIndexLen = (int) decryptShare[3];
|
||||
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
|
||||
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen));
|
||||
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize);
|
||||
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize);
|
||||
NewArray<byte[]> sSkVu = new NewArray<>();
|
||||
sSkVu.setSize(sSize);
|
||||
sSkVu.setArray(sSkUv);
|
||||
NewArray bVu = new NewArray();
|
||||
bVu.setSize(bSize);
|
||||
bVu.setArray(bUv);
|
||||
decryptShareSecrets.setFlID(vFlID);
|
||||
decryptShareSecrets.setSSkVu(sSkVu);
|
||||
decryptShareSecrets.setBVu(bVu);
|
||||
decryptShareSecrets.setSIndex(sIndex);
|
||||
decryptShareSecrets.setIndexB(bIndex);
|
||||
decryptSecretsList.add(decryptShareSecrets);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import org.bouncycastle.crypto.digests.SHA256Digest;
|
||||
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
|
||||
import org.bouncycastle.crypto.params.KeyParameter;
|
||||
import org.bouncycastle.math.ec.rfc7748.X25519;
|
||||
|
||||
import java.security.SecureRandom;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class KEYAgreement {
|
||||
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
|
||||
private static final int PBKDF2_ITERATIONS = 10000;
|
||||
private static final int SALT_SIZE = 32;
|
||||
private static final int HASH_BIT_SIZE = 256;
|
||||
private static final int KEY_LEN = X25519.SCALAR_SIZE;
|
||||
|
||||
private SecureRandom random = new SecureRandom();
|
||||
|
||||
public byte[] generatePrivateKey() {
|
||||
byte[] privateKey = new byte[KEY_LEN];
|
||||
X25519.generatePrivateKey(random, privateKey);
|
||||
return privateKey;
|
||||
}
|
||||
|
||||
public byte[] generatePublicKey(byte[] privatekey) {
|
||||
byte[] publicKey = new byte[KEY_LEN];
|
||||
X25519.generatePublicKey(privatekey, 0, publicKey, 0);
|
||||
return publicKey;
|
||||
}
|
||||
|
||||
public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) {
|
||||
byte[] secret = new byte[KEY_LEN];
|
||||
X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0);
|
||||
return secret;
|
||||
}
|
||||
|
||||
public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
|
||||
|
||||
byte[] saltB = new byte[SALT_SIZE];
|
||||
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
|
||||
gen.init(password, saltB, PBKDF2_ITERATIONS);
|
||||
byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
|
||||
return dk;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class Random {
|
||||
/**
|
||||
* random generate RNG algorithm name
|
||||
*/
|
||||
private static final Logger LOGGER = Logger.getLogger(Random.class.toString());
|
||||
private static final String RNG_ALGORITHM = "SHA1PRNG";
|
||||
|
||||
private static final int RANDOM_LEN = 128 / 8;
|
||||
|
||||
public void getRandomBytes(byte[] secret) {
|
||||
try {
|
||||
SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
|
||||
secureRandom.nextBytes(secret);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception {
|
||||
int intV = Integer.SIZE / 8;
|
||||
int size = length * intV;
|
||||
byte[] data = new byte[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
data[i] = 0;
|
||||
}
|
||||
byte[] ivec = new byte[RANDOM_LEN];
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR");
|
||||
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data);
|
||||
for (int i = 0; i < length; i++) {
|
||||
int[] sub = new int[intV];
|
||||
for (int j = 0; j < 4; j++) {
|
||||
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
|
||||
}
|
||||
int subI = byte2int(sub, 4);
|
||||
|
||||
Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
|
||||
noise.add(f);
|
||||
}
|
||||
}
|
||||
|
||||
public static int byte2int(int[] data, int n) {
|
||||
switch (n) {
|
||||
case 1:
|
||||
return (int) data[0];
|
||||
case 2:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
|
||||
case 3:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
|
||||
case 4:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
|
||||
| (data[3] << 24 & 0xff000000);
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.FLCommunication;
|
||||
import com.mindspore.flclient.FLParameter;
|
||||
import com.mindspore.flclient.LocalFLParameter;
|
||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||
import mindspore.schema.ClientShare;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class ReconstructSecretReq {
|
||||
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
private String nextRequestTime;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public ReconstructSecretReq() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
}
|
||||
|
||||
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============sendReconstructSecret url: " + url + "=============="));
|
||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||
int desFlId = builder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
int time = builder.createString(dateTime);
|
||||
int shareSecretsSize = decryptShareSecretsList.size();
|
||||
if (shareSecretsSize <= 0) {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite."));
|
||||
return FLClientStatus.FAILED;
|
||||
} else {
|
||||
int[] decryptShareList = new int[shareSecretsSize];
|
||||
for (int i = 0; i < shareSecretsSize; i++) {
|
||||
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
|
||||
String srcFlId = decryptShareSecrets.getFlID();
|
||||
byte[] share;
|
||||
int index;
|
||||
if (u3ClientList.contains(srcFlId)) {
|
||||
share = decryptShareSecrets.getBVu().getArray();
|
||||
index = decryptShareSecrets.getIndexB();
|
||||
} else {
|
||||
share = decryptShareSecrets.getSSkVu().getArray();
|
||||
index = decryptShareSecrets.getSIndex();
|
||||
}
|
||||
int fbsSrcFlId = builder.createString(srcFlId);
|
||||
int fbsShare = ClientShare.createShareVector(builder, share);
|
||||
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
|
||||
decryptShareList[i] = clientShare;
|
||||
}
|
||||
int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList);
|
||||
int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time);
|
||||
builder.finish(reconstructSecretRoot);
|
||||
byte[] msg = builder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer);
|
||||
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
|
||||
int retcode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
|
||||
switch (retcode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
case (ResponseCode.SystemError):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReconstructSecret is invalid: " + retcode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Random;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
||||
public class ShareSecrets {
|
||||
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
|
||||
|
||||
public final class SecretShare {
|
||||
public SecretShare(final int num, final BigInteger share) {
|
||||
this.num = num;
|
||||
this.share = share;
|
||||
}
|
||||
|
||||
public int getNum() {
|
||||
return num;
|
||||
}
|
||||
|
||||
public BigInteger getShare() {
|
||||
return share;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "SecretShare [num=" + num + ", share=" + share + "]";
|
||||
}
|
||||
|
||||
private final int num;
|
||||
private final BigInteger share;
|
||||
}
|
||||
|
||||
public ShareSecrets(final int k, final int n) {
|
||||
this.k = k;
|
||||
this.n = n;
|
||||
|
||||
random = new Random();
|
||||
}
|
||||
|
||||
public SecretShare[] split(final byte[] bytes, byte[] primeByte) {
|
||||
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
|
||||
final int modLength = secret.bitLength() + 1;
|
||||
prime = BaseUtil.byteArray2BigInteger(primeByte);
|
||||
final BigInteger[] coeff = new BigInteger[k - 1];
|
||||
|
||||
LOGGER.info(Common.addTag("Prime Number: " + prime));
|
||||
|
||||
for (int i = 0; i < k - 1; i++) {
|
||||
coeff[i] = randomZp(prime);
|
||||
LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i]));
|
||||
}
|
||||
|
||||
final SecretShare[] shares = new SecretShare[n];
|
||||
for (int i = 1; i <= n; i++) {
|
||||
BigInteger accum = secret;
|
||||
|
||||
for (int j = 1; j < k; j++) {
|
||||
final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
|
||||
final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime);
|
||||
|
||||
accum = accum.add(t2).mod(prime);
|
||||
}
|
||||
shares[i - 1] = new SecretShare(i, accum);
|
||||
LOGGER.info(Common.addTag("Share " + shares[i - 1]));
|
||||
}
|
||||
return shares;
|
||||
}
|
||||
|
||||
public BigInteger getPrime() {
|
||||
return prime;
|
||||
}
|
||||
|
||||
public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) {
|
||||
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
|
||||
BigInteger accum = BigInteger.ZERO;
|
||||
for (int j = 0; j < k; j++) {
|
||||
BigInteger num = BigInteger.ONE;
|
||||
BigInteger den = BigInteger.ONE;
|
||||
|
||||
BigInteger tmp;
|
||||
|
||||
for (int m = 0; m < k; m++) {
|
||||
if (j != m) {
|
||||
num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum);
|
||||
tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1));
|
||||
tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum);
|
||||
den = den.multiply(tmp).mod(primeNum);
|
||||
}
|
||||
}
|
||||
final BigInteger value = shares[j].getShare();
|
||||
|
||||
tmp = den.modInverse(primeNum);
|
||||
tmp = tmp.multiply(num).mod(primeNum);
|
||||
tmp = tmp.multiply(value).mod(primeNum);
|
||||
accum = accum.add(tmp).mod(primeNum);
|
||||
LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum));
|
||||
}
|
||||
LOGGER.info(Common.addTag("The secret is: " + accum));
|
||||
return accum;
|
||||
}
|
||||
|
||||
private BigInteger randomZp(final BigInteger p) {
|
||||
while (true) {
|
||||
final BigInteger r = new BigInteger(p.bitLength(), random);
|
||||
if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private BigInteger prime;
|
||||
private final int k;
|
||||
private final int n;
|
||||
private final Random random;
|
||||
private final int SECRET_MAX_LEN = 32;
|
||||
}
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
public class ClientPublicKey {
|
||||
private String flID;
|
||||
private NewArray<byte[]> cPK;
|
||||
private NewArray<byte[]> sPk;
|
||||
|
||||
public String getFlID() {
|
||||
return flID;
|
||||
}
|
||||
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> getCPK() {
|
||||
return cPK;
|
||||
}
|
||||
|
||||
public void setCPK(NewArray<byte[]> cPK) {
|
||||
this.cPK = cPK;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> getSPK() {
|
||||
return sPk;
|
||||
}
|
||||
|
||||
public void setSPK(NewArray<byte[]> sPk) {
|
||||
this.sPk = sPk;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
public class DecryptShareSecrets {
|
||||
private String flID;
|
||||
private NewArray<byte[]> sSkVu;
|
||||
private NewArray<byte[]> bVu;
|
||||
private int sIndex;
|
||||
private int indexB;
|
||||
|
||||
public String getFlID() {
|
||||
return flID;
|
||||
}
|
||||
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> getSSkVu() {
|
||||
return sSkVu;
|
||||
}
|
||||
|
||||
public void setSSkVu(NewArray<byte[]> sSkVu) {
|
||||
this.sSkVu = sSkVu;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> getBVu() {
|
||||
return bVu;
|
||||
}
|
||||
|
||||
public void setBVu(NewArray<byte[]> bVu) {
|
||||
this.bVu = bVu;
|
||||
}
|
||||
|
||||
public int getSIndex() {
|
||||
return sIndex;
|
||||
}
|
||||
|
||||
public void setSIndex(int sIndex) {
|
||||
this.sIndex = sIndex;
|
||||
}
|
||||
|
||||
public int getIndexB() {
|
||||
return indexB;
|
||||
}
|
||||
|
||||
public void setIndexB(int indexB) {
|
||||
this.indexB = indexB;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
public class EncryptShare {
|
||||
private String flID;
|
||||
private NewArray<byte[]> share;
|
||||
|
||||
public String getFlID() {
|
||||
return flID;
|
||||
}
|
||||
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> getShare() {
|
||||
return share;
|
||||
}
|
||||
|
||||
public void setShare(NewArray<byte[]> share) {
|
||||
this.share = share;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
public class NewArray<T> {
|
||||
private int size;
|
||||
private T array;
|
||||
|
||||
public int getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
public void setSize(int size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
public T getArray() {
|
||||
return array;
|
||||
}
|
||||
|
||||
public void setArray(T array) {
|
||||
this.array = array;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
public class ShareSecret {
|
||||
private String flID;
|
||||
private NewArray<byte[]> share;
|
||||
private int index;
|
||||
|
||||
public String getFlID() {
|
||||
return flID;
|
||||
}
|
||||
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> getShare() {
|
||||
return share;
|
||||
}
|
||||
|
||||
public void setShare(NewArray<byte[]> share) {
|
||||
this.share = share;
|
||||
}
|
||||
|
||||
public int getIndex() {
|
||||
return index;
|
||||
}
|
||||
|
||||
public void setIndex(int index) {
|
||||
this.index = index;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue