change dir and add Encrypt code

This commit is contained in:
zhoushan 2021-06-19 18:04:30 +08:00
parent f98497ca09
commit b7b58a2ab9
22 changed files with 2279 additions and 33 deletions

View File

@ -0,0 +1,578 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.cipher.AESEncrypt;
import com.mindspore.flclient.cipher.BaseUtil;
import com.mindspore.flclient.cipher.ClientListReq;
import com.mindspore.flclient.cipher.KEYAgreement;
import com.mindspore.flclient.cipher.Random;
import com.mindspore.flclient.cipher.ReconstructSecretReq;
import com.mindspore.flclient.cipher.ShareSecrets;
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray;
import com.mindspore.flclient.cipher.struct.ShareSecret;
import mindspore.schema.ClientShare;
import mindspore.schema.GetExchangeKeys;
import mindspore.schema.GetShareSecrets;
import mindspore.schema.RequestExchangeKeys;
import mindspore.schema.RequestShareSecrets;
import mindspore.schema.ResponseCode;
import mindspore.schema.ResponseExchangeKeys;
import mindspore.schema.ResponseShareSecrets;
import mindspore.schema.ReturnExchangeKeys;
import mindspore.schema.ReturnShareSecrets;
import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
public class CipherClient {
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
private FLCommunication flCommunication;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private final int iteration;
private int featureSize;
private int t;
private List<byte[]> cKey = new ArrayList<>();
private List<byte[]> sKey = new ArrayList<>();
private byte[] bu;
private String nextRequestTime;
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
private List<EncryptShare> clientShareList = new ArrayList<>();
private List<EncryptShare> returnShareList = new ArrayList<>();
private float[] featureMask;
private List<String> u1ClientList = new ArrayList<>();
private List<String> u2UClientList = new ArrayList<>();
private List<String> u3ClientList = new ArrayList<>();
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
private byte[] prime;
private KEYAgreement keyAgreement = new KEYAgreement();
private Random random = new Random();
private ClientListReq clientListReq = new ClientListReq();
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
flCommunication = FLCommunication.getInstance();
this.iteration = iter;
this.featureSize = featureSize;
this.t = minSecretNum;
this.prime = prime;
this.featureMask = new float[this.featureSize];
}
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public void setBU(byte[] bu) {
this.bu = bu;
}
public void setClientShareList(List<EncryptShare> clientShareList) {
this.clientShareList.clear();
this.clientShareList = clientShareList;
}
public String getNextRequestTime() {
return nextRequestTime;
}
public void genDHKeyPairs() {
byte[] csk = keyAgreement.generatePrivateKey();
byte[] cpk = keyAgreement.generatePublicKey(csk);
byte[] ssk = keyAgreement.generatePrivateKey();
byte[] spk = keyAgreement.generatePublicKey(ssk);
this.cKey.add(cpk);
this.cKey.add(csk);
this.sKey.add(spk);
this.sKey.add(ssk);
}
public void genIndividualSecret() {
byte[] key = new byte[SEED_SIZE];
random.getRandomBytes(key);
setBU(key);
}
public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException {
List<ShareSecret> shareSecretList = new ArrayList<>();
int size = u1ClientList.size();
ShareSecrets shamir = new ShareSecrets(t, size - 1);
ShareSecrets.SecretShare[] shares = shamir.split(secret, prime);
int j = 0;
for (int i = 0; i < size; i++) {
String vFlID = u1ClientList.get(i);
if (localFLParameter.getFlID().equals(vFlID)) {
continue;
} else {
ShareSecret shareSecret = new ShareSecret();
NewArray<byte[]> array = new NewArray<>();
int index = shares[j].getNum();
BigInteger intShare = shares[j].getShare();
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
array.setSize(share.length);
array.setArray(share);
shareSecret.setFlID(vFlID);
shareSecret.setShare(array);
shareSecret.setIndex(index);
shareSecretList.add(shareSecret);
j += 1;
}
}
return shareSecretList;
}
public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException {
cUVKeys.clear();
for (String key : clientPublicKeyList.keySet()) {
ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
String vFlID = curPublicKey.getFlID();
if (localFLParameter.getFlID().equals(vFlID)) {
continue;
} else {
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
byte[] salt = new byte[0];
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
cUVKeys.put(vFlID, secret);
}
}
}
public void encryptShares() throws Exception {
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************"));
List<EncryptShare> encryptShareList = new ArrayList<>();
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
List<ShareSecret> bUV = genSecretShares(bu);
for (int i = 0; i < bUV.size(); i++) {
EncryptShare encryptShare = new EncryptShare();
NewArray<byte[]> array = new NewArray<>();
String vFlID = bUV.get(i).getFlID();
byte[] sShare = sSkUv.get(i).getShare().getArray();
byte[] bShare = bUV.get(i).getShare().getArray();
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
byte[] bIndex = BaseUtil.integer2byteArray(bUV.get(i).getIndex());
byte[] allSecret = new byte[sShare.length + bShare.length + sIndex.length + bIndex.length + 4];
allSecret[0] = (byte) sShare.length;
allSecret[1] = (byte) bShare.length;
allSecret[2] = (byte) sIndex.length;
allSecret[3] = (byte) bIndex.length;
System.arraycopy(sIndex, 0, allSecret, 4, sIndex.length);
System.arraycopy(bIndex, 0, allSecret, 4 + sIndex.length, bIndex.length);
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
// encrypt:
byte[] iVecIn = new byte[IVEC_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC");
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
array.setSize(encryptData.length);
array.setArray(encryptData);
encryptShare.setFlID(vFlID);
encryptShare.setShare(array);
encryptShareList.add(encryptShare);
}
setClientShareList(encryptShareList);
}
public float[] doubleMaskingWeight() throws Exception {
int size = u2UClientList.size();
List<Float> noiseBu = new ArrayList<>();
random.randomAESCTR(noiseBu, featureSize, bu);
float[] mask = new float[featureSize];
for (int i = 0; i < size; i++) {
String vFlID = u2UClientList.get(i);
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
if (localFLParameter.getFlID().equals(vFlID)) {
continue;
} else {
byte[] salt = new byte[0];
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
sUVKeys.put(vFlID, secret);
List<Float> noiseSuv = new ArrayList<>();
random.randomAESCTR(noiseSuv, featureSize, secret);
int sign;
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
sign = 1;
} else {
sign = -1;
}
for (int j = 0; j < noiseSuv.size(); j++) {
mask[j] = mask[j] + sign * noiseSuv.get(j);
}
}
}
for (int j = 0; j < noiseBu.size(); j++) {
mask[j] = mask[j] + noiseBu.get(j);
}
return mask;
}
public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
NewArray<byte[]> newArray = new NewArray<>();
newArray.setSize(size);
byte[] array = new byte[size];
for (int i = 0; i < size; i++) {
byte word = buf.get();
array[i] = word;
}
newArray.setArray(array);
return newArray;
}
public FLClientStatus requestExchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="));
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============requestExchangeKeys url: " + url + "=============="));
genDHKeyPairs();
byte[] cPK = cKey.get(0);
byte[] sPK = sKey.get(0);
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
String dateTime = LocalDateTime.now().toString();
int time = fbBuilder.createString(dateTime);
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time);
fbBuilder.finish(exchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys);
return status;
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
int retcode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseExchangeKeys is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus getExchangeKeys() {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getExchangeKeys url: " + url + "=============="));
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
int time = fbBuilder.createString(dateTime);
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
fbBuilder.finish(getExchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys);
return status;
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
int retcode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys success"));
clientPublicKeyList.clear();
u1ClientList.clear();
int length = bufData.remotePublickeysLength();
for (int i = 0; i < length; i++) {
ClientPublicKey publicKey = new ClientPublicKey();
String srcFlId = bufData.remotePublickeys(i).flId();
publicKey.setFlID(srcFlId);
ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer();
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
clientPublicKeyList.put(srcFlId, publicKey);
u1ClientList.add(srcFlId);
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnExchangeKeys is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus requestShareSecrets() throws Exception {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============requestShareSecrets url: " + url + "=============="));
genIndividualSecret();
genEncryptExchangedKeys();
encryptShares();
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
int time = fbBuilder.createString(dateTime);
int clientShareSize = clientShareList.size();
if (clientShareSize <= 0) {
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
Common.sleep(SLEEP_TIME);
FLClientStatus status = requestShareSecrets();
return status;
} else {
int[] add = new int[clientShareSize];
for (int i = 0; i < clientShareSize; i++) {
int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray());
ClientShare.startClientShare(fbBuilder);
ClientShare.addFlId(fbBuilder, flID);
ClientShare.addShare(fbBuilder, shareSecretFbs);
int clientShareRoot = ClientShare.endClientShare(fbBuilder);
add[i] = clientShareRoot;
}
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time);
fbBuilder.finish(requestShareSecretsRoot);
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets);
return status;
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
}
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
int retcode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseShareSecrets is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus getShareSecrets() {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getShareSecrets url: " + url + "=============="));
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
int time = fbBuilder.createString(dateTime);
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
fbBuilder.finish(getShareSecrets);
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets);
return status;
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
int retcode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
LOGGER.info(Common.addTag("[PairWiseMask] the size of encrypted shares: " + bufData.encryptedSharesLength()));
switch (retcode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets success"));
returnShareList.clear();
u2UClientList.clear();
int length = bufData.encryptedSharesLength();
for (int i = 0; i < length; i++) {
EncryptShare shareSecret = new EncryptShare();
shareSecret.setFlID(bufData.encryptedShares(i).flId());
ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer();
int sizeShare = bufData.encryptedShares(i).shareLength();
shareSecret.setShare(byteToArray(bufShare, sizeShare));
returnShareList.add(shareSecret);
u2UClientList.add(bufData.encryptedShares(i).flId());
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnShareSecrets is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus exchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================"));
FLClientStatus curStatus;
// RequestExchangeKeys
curStatus = requestExchangeKeys();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = requestExchangeKeys();
}
if (curStatus != FLClientStatus.SUCCESS) {
return curStatus;
}
// GetExchangeKeys
curStatus = getExchangeKeys();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = getExchangeKeys();
}
return curStatus;
}
public FLClientStatus shareSecrets() throws Exception {
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================")));
FLClientStatus curStatus;
// RequestShareSecrets
curStatus = requestShareSecrets();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = requestShareSecrets();
}
if (curStatus != FLClientStatus.SUCCESS) {
return curStatus;
}
// GetShareSecrets
curStatus = getShareSecrets();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = getShareSecrets();
}
return curStatus;
}
public FLClientStatus reconstructSecrets() {
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================"));
FLClientStatus curStatus;
// GetClientList
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
}
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = clientListReq.getNextRequestTime();
}
if (curStatus != FLClientStatus.SUCCESS) {
return curStatus;
}
// SendReconstructSecret
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration);
}
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = reconstructSecretReq.getNextRequestTime();
}
return curStatus;
}
}

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
public class Common {
public static final String LOG_TITLE = "<FLClient> ";
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "adbert"));
public static String generateUrl(boolean useElb, String ip, int port, int serverNum) {
String url;
if (useElb) {
Random rand = new Random();
int randomNum = rand.nextInt(100000) % serverNum + port;
url = ip + String.valueOf(randomNum);
} else {
url = ip + String.valueOf(port);
}
return url;
}
public static void setClassifierWeightName(List<String> classifierWeightName) {
classifierWeightName.add("albert.pooler.weight");
classifierWeightName.add("albert.pooler.bias");
classifierWeightName.add("classifier.weight");
classifierWeightName.add("classifier.bias");
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
}
public static void setAlbertWeightName(List<String> albertWeightName) {
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma");
albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta");
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
}
public static boolean checkFLName(String flName) {
return (flNameTrustList.contains(flName));
}
public static void sleep(long millis) {
try {
Thread.sleep(millis); //1000 milliseconds is one second.
} catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
Thread.currentThread().interrupt();
}
}
public static long getWaitTime(String nextRequestTime) {
Date date = new Date();
long currentTime = date.getTime();
long waitTime = 0;
if (!nextRequestTime.equals("")) {
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
}
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
return waitTime;
}
public static long startTime(String tag) {
Date startDate = new Date();
long startTime = startDate.getTime();
LOGGER.info(addTag("[start time] <" + tag + "> start time: " + startTime));
return startTime;
}
public static void endTime(long start, String tag) {
Date endDate = new Date();
long endTime = endDate.getTime();
LOGGER.info(addTag("[end time] <" + tag + "> end time: " + endTime));
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
}
public static String addTag(String message) {
return LOG_TITLE + message;
}
public static boolean isAutoscaling(byte[] message, String autoscalingTag) {
return (new String(message)).contains(autoscalingTag);
}
public static boolean checkPath(String path) {
File file = new File(path);
return file.exists();
}
}

View File

@ -1,12 +1,12 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

View File

@ -1,12 +1,12 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

View File

@ -0,0 +1,163 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.TIME_OUT;
public class FLCommunication implements IFLCommunication {
private static int timeOut;
private static boolean ssl = false;
private FLParameter flParameter = FLParameter.getInstance();
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
private OkHttpClient client;
private static FLCommunication communication;
private FLCommunication() {
if (flParameter.getTimeOut() != 0) {
timeOut = flParameter.getTimeOut();
} else {
timeOut = TIME_OUT;
}
ssl = flParameter.isUseSSL();
client = getUnsafeOkHttpClient();
}
private static OkHttpClient getUnsafeOkHttpClient() {
X509TrustManager trustManager = new X509TrustManager() {
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[]{};
}
@Override
public void checkServerTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
}
@Override
public void checkClientTrusted(X509Certificate[] arg0, String arg1) throws CertificateException {
}
};
final TrustManager[] trustAllCerts = new TrustManager[]{trustManager};
try {
LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut));
OkHttpClient.Builder builder = new OkHttpClient.Builder();
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
if (ssl) {
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager());
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
} else {
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
builder.sslSocketFactory(sslSocketFactory, trustManager);
builder.hostnameVerifier(new HostnameVerifier() {
@Override
public boolean verify(String arg0, SSLSession arg1) {
return true;
}
});
}
return builder.build();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage()));
throw new RuntimeException(e);
}
}
public static FLCommunication getInstance() {
if (communication == null) {
synchronized (FLCommunication.class) {
if (communication == null) {
communication = new FLCommunication();
}
}
}
return communication;
}
@Override
public void setTimeOut(int timeout) throws TimeoutException {
}
@Override
public byte[] syncRequest(String url, byte[] msg) throws IOException {
Request request = new Request.Builder()
.url(url)
.post(RequestBody.create(MEDIA_TYPE_JSON, msg)).build();
Response response = this.client.newCall(request).execute();
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}
return response.body().bytes();
}
@Override
public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception {
Request request = new Request.Builder()
.url(url)
.header("Accept", "application/proto")
.header("Content-Type", "application/proto; charset=utf-8")
.post(RequestBody.create(MEDIA_TYPE_JSON, msg)).build();
client.newCall(request).enqueue(new Callback() {
IAsyncCallBack asyncCallBack = callBack;
@Override
public void onResponse(Call call, Response response) throws IOException {
asyncCallBack.onResponse(response.body().bytes());
call.cancel();
}
@Override
public void onFailure(Call call, IOException e) {
asyncCallBack.onFailure(e);
call.cancel();
}
});
}
}

View File

@ -1,12 +1,12 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -17,13 +17,15 @@ package com.mindspore.flclient;
import java.util.logging.Logger;
public class FLJobResultCallback implements IFLJobResultCallback{
private static final Logger logger = Logger.getLogger(FLJobResultCallback.class.toString());
public class FLJobResultCallback implements IFLJobResultCallback {
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
logger.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
}
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
logger.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
}
}
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
}
}

View File

@ -1,12 +1,12 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

View File

@ -1,12 +1,12 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -15,10 +15,18 @@
*/
package com.mindspore.flclient;
import javax.net.ssl.*;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.FileInputStream;
import java.io.InputStream;
import java.security.*;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
@ -32,11 +40,12 @@ public class SSLSocketFactoryTools {
private SSLContext sslContext;
private MyTrustManager myTrustManager;
private static SSLSocketFactoryTools instance;
private SSLSocketFactoryTools() {
initSslSocketFactory();
}
private void initSslSocketFactory(){
private void initSslSocketFactory() {
try {
sslContext = SSLContext.getInstance("TLS");
x509Certificate = readCert(flParameter.getCertPath());
@ -51,16 +60,14 @@ public class SSLSocketFactoryTools {
}
}
public static SSLSocketFactoryTools getInstance() {
if (instance == null) {
instance=new SSLSocketFactoryTools();
instance = new SSLSocketFactoryTools();
}
return instance;
}
public X509Certificate readCert(String assetName) {
public X509Certificate readCert(String assetName) {
InputStream inputStream = null;
try {
inputStream = new FileInputStream(assetName);
@ -110,7 +117,6 @@ public class SSLSocketFactoryTools {
public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
for (X509Certificate cert : chain) {
@ -130,6 +136,7 @@ public class SSLSocketFactoryTools {
} catch (SignatureException e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
}
logger.info(Common.addTag("checkServerTrusted success!"));
}
}

View File

@ -0,0 +1,323 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.logging.Logger;
public class SecureProtocol {
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int iteration;
private CipherClient cipher;
private FLClientStatus status;
private float[] featureMask;
private double dpEps;
private double dpDelta;
private double dpNormClip;
private static double deltaError = 1e-6;
private static Map<String, float[]> modelMap;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
public FLClientStatus getStatus() {
return status;
}
public float[] getFeatureMask() {
return featureMask;
}
public SecureProtocol() {
}
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
this.iteration = iter;
this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize);
}
public FLClientStatus setDPParameter(int iter, double diffEps,
double diffDelta, double diffNorm, Map<String, float[]> map) {
try {
this.iteration = iter;
this.dpEps = diffEps;
this.dpDelta = diffDelta;
this.dpNormClip = diffNorm;
this.modelMap = map;
status = FLClientStatus.SUCCESS;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage()));
status = FLClientStatus.FAILED;
}
return status;
}
public ArrayList<String> getEncryptFeatureName() {
return encryptFeatureName;
}
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
this.encryptFeatureName = encryptFeatureName;
}
public String getNextRequestTime() {
return cipher.getNextRequestTime();
}
public FLClientStatus pwCreateMask() {
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
// round 0
status = cipher.exchangeKeys();
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
if (status != FLClientStatus.SUCCESS) {
return status;
}
// round 1
try {
status = cipher.shareSecrets();
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
status = FLClientStatus.FAILED;
}
if (status != FLClientStatus.SUCCESS) {
return status;
}
// round2
try {
featureMask = cipher.doubleMaskingWeight();
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
status = FLClientStatus.FAILED;
}
return status;
}
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
int featureSize = encryptFeatureName.size();
int[] featuresMap = new int[featureSize];
int maskIndex = 0;
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length);
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
float maskData = rawData * trainDataSize + featureMask[maskIndex];
maskIndex += 1;
data[j] = maskData;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
featuresMap[i] = featureMap;
}
return featuresMap;
}
public FLClientStatus pwUnmasking() {
status = cipher.reconstructSecrets(); // round3
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
return status;
}
private static float calculateErf(double x) {
double result = 0;
int segmentNum = 10000;
double deltaX = x / segmentNum;
result += 1;
for (int i = 1; i < segmentNum; i++) {
result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
}
result += Math.exp(-Math.pow(deltaX * segmentNum, 2));
return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
}
private static double calculatePhi(double t) {
return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0))));
}
private static double calculateBPositive(double eps, double s) {
return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
}
private static double calculateBNegative(double eps, double s) {
return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
}
private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) {
double deltaSup = calculateBPositive(eps, sSup);
while (deltaSup <= targetDelta) {
sInf = sSup;
sSup = 2 * sInf;
deltaSup = calculateBPositive(eps, sSup);
}
double sMid = sInf + (sSup - sInf) / 2.0;
int iterMax = 1000;
int iters = 0;
while (true) {
double b = calculateBPositive(eps, sMid);
if (b <= targetDelta) {
if (targetDelta - b <= deltaError) {
break;
} else {
sInf = sMid;
}
} else {
sSup = sMid;
}
sMid = sInf + (sSup - sInf) / 2.0;
iters += 1;
if (iters > iterMax) {
break;
}
}
return sMid;
}
private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) {
double deltaSup = calculateBNegative(eps, sSup);
while (deltaSup > targetDelta) {
sInf = sSup;
sSup = 2 * sInf;
deltaSup = calculateBNegative(eps, sSup);
}
double sMid = sInf + (sSup - sInf) / 2.0;
int iterMax = 1000;
int iters = 0;
while (true) {
double b = calculateBNegative(eps, sMid);
if (b <= targetDelta) {
if (targetDelta - b <= deltaError) {
break;
} else {
sSup = sMid;
}
} else {
sInf = sMid;
}
sMid = sInf + (sSup - sInf) / 2.0;
iters += 1;
if (iters > iterMax) {
break;
}
}
return sMid;
}
private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
double deltaZero = calculateBPositive(eps, 0);
double alpha = 1;
if (targetDelta > deltaZero) {
double s = calculateSPositive(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0);
} else if (targetDelta < deltaZero) {
double s = calculateSNegative(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0);
}
return alpha * clipNorm / Math.sqrt(2.0 * eps);
}
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
Map<String, float[]> mapBeforeTrain = modelMap;
int featureSize = encryptFeatureName.size();
// calculate sigma
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
// prepare gaussian noise
SecureRandom random = new SecureRandom();
int randomInt = random.nextInt();
Random r = new Random(randomInt);
// calculate l2-norm of all layers' update array
double updateL2Norm = 0;
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain;
updateL2Norm += updateData * updateData;
}
}
updateL2Norm = Math.sqrt(updateL2Norm);
double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
// clip and add noise
int[] featuresMap = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
float[] data2 = new float[data.length];
float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain;
// clip
updateData *= clipFactor;
// add noise
double gaussianNoise = r.nextGaussian() * gaussianSigma;
updateData += gaussianNoise;
data2[j] = rawDataBeforeTrain + updateData;
data2[j] = data2[j] * trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data2);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
featuresMap[i] = featureMap;
}
return featuresMap;
}
}

View File

@ -1,12 +1,12 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

View File

@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.mindspore.flclient.Common;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.UnsupportedEncodingException;
import java.util.logging.Logger;
public class AESEncrypt {
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
/**
* 128, 192 or 256
*/
private static final int KEY_SIZE = 256;
private static final int I_VEC_LEN = 16;
/**
* encrypt/decrypt algorithm name
*/
private static final String ALGORITHM = "AES";
/**
* algorithm/Mode/padding mode
*/
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
private String CIPHER_MODE;
private static final int RANDOM_LEN = KEY_SIZE / 8;
private String iVecS = "1111111111111111";
private byte[] iVec = iVecS.getBytes("utf-8");
public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return;
}
if (key.length != KEY_SIZE / 8) {
LOGGER.severe(Common.addTag("the length of key is not correct"));
return;
}
if (mode.contains("CBC")) {
CIPHER_MODE = CIPHER_MODE_CBC;
} else if (mode.contains("CTR")) {
CIPHER_MODE = CIPHER_MODE_CTR;
} else {
return;
}
if (iVecIn == null || iVecIn.length != I_VEC_LEN) {
return;
}
iVec = iVecIn;
}
public byte[] encrypt(byte[] key, byte[] data) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
byte[] encrypted = cipher.doFinal(data);
String encryptResultStr = BaseUtil.byte2HexString(encrypted);
return encrypted;
}
public byte[] encryptCTR(byte[] key, byte[] data) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
byte[] encrypted = cipher.doFinal(data);
return encrypted;
}
public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv);
byte[] origin = cipher.doFinal(encryptData);
return origin;
}
}

View File

@ -0,0 +1,143 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
public class BaseUtil {
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
public BaseUtil() {
}
public static String byte2HexString(byte[] bytes) {
if (null == bytes) {
return null;
} else if (bytes.length == 0) {
return "";
} else {
char[] chars = new char[bytes.length * 2];
for (int i = 0; i < bytes.length; ++i) {
int b = bytes[i];
chars[i * 2] = HEX_DIGITS[(b & 240) >> 4];
chars[i * 2 + 1] = HEX_DIGITS[b & 15];
}
return new String(chars);
}
}
public static byte[] hexString2ByteArray(String str) {
int length = str.length() / 2;
byte[] bytes = new byte[length];
byte[] source = str.getBytes(Charset.forName("UTF-8"));
for (int i = 0; i < bytes.length; ++i) {
byte bh = Byte.decode("0x" + new String(new byte[]{source[i * 2]}, Charset.forName("UTF-8")));
bh = (byte) (bh << 4);
byte bl = Byte.decode("0x" + new String(new byte[]{source[i * 2 + 1]}, Charset.forName("UTF-8")));
bytes[i] = (byte) (bh ^ bl);
}
return bytes;
}
public static BigInteger byteArray2BigInteger(byte[] bytes) {
BigInteger bigInteger = BigInteger.ZERO;
for (int i = 0; i < bytes.length; ++i) {
int intI = bytes[i];
if (intI < 0) {
intI = intI + 256;
}
BigInteger bi = new BigInteger(String.valueOf(intI));
bigInteger = bigInteger.multiply(BigInteger.valueOf(256)).add(bi);
}
return bigInteger;
}
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
StringBuilder res = new StringBuilder();
byte[] bytes = String.valueOf(str).getBytes("UTF-8");
BigInteger bigInteger = BigInteger.ZERO;
for (int i = 0; i < str.length(); ++i) {
BigInteger bi = new BigInteger(String.valueOf(bytes[i]));
bigInteger = bigInteger.multiply(BigInteger.valueOf(256)).add(bi);
}
return bigInteger;
}
public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException {
StringBuilder res = new StringBuilder();
List<Integer> lists = new ArrayList<>();
BigInteger bi = bigInteger;
BigInteger DIV = BigInteger.valueOf(256);
while (bi.compareTo(BigInteger.ZERO) > 0) {
lists.add(bi.mod(DIV).intValue());
bi = bi.divide(DIV);
}
for (int i = lists.size() - 1; i >= 0; --i) {
res.append((char) (int) (lists.get(i)));
}
return res.toString();
}
public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException {
List<Integer> lists = new ArrayList<>();
BigInteger bi = bigInteger;
BigInteger DIV = BigInteger.valueOf(256);
while (bi.compareTo(BigInteger.ZERO) > 0) {
lists.add(bi.mod(DIV).intValue());
bi = bi.divide(DIV);
}
byte[] res = new byte[lists.size()];
for (int i = lists.size() - 1; i >= 0; --i) {
res[lists.size() - i - 1] = ((byte) (int) (lists.get(i)));
}
return res;
}
public static byte[] integer2byteArray(Integer num) {
List<Integer> lists = new ArrayList<>();
Integer bi = num;
Integer DIV = 256;
while (bi > 0) {
lists.add(bi % DIV);
bi = bi / DIV;
}
byte[] res = new byte[lists.size()];
for (int i = lists.size() - 1; i >= 0; --i) {
res[lists.size() - i - 1] = ((byte) (int) (lists.get(i)));
}
return res;
}
public static Integer byteArray2Integer(byte[] bytes) {
Integer num = 0;
for (int i = 0; i < bytes.length; ++i) {
int intI = bytes[i];
if (intI < 0) {
intI = intI + 256;
}
num = num * 256 + intI;
}
return num;
}
}

View File

@ -0,0 +1,158 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray;
import mindspore.schema.GetClientList;
import mindspore.schema.ResponseCode;
import mindspore.schema.ReturnClientList;
import java.nio.ByteBuffer;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
public class ClientListReq {
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
private FLCommunication flCommunication;
private String nextRequestTime;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
public ClientListReq() {
flCommunication = FLCommunication.getInstance();
}
public String getNextRequestTime() {
return nextRequestTime;
}
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "=============="));
FlatBufferBuilder builder = new FlatBufferBuilder();
int id = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
int time = builder.createString(dateTime);
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
builder.finish(clientListRoot);
byte[] msg = builder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
return status;
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
int retcode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
LOGGER.info(Common.addTag("[PairWiseMask] the size of clients: " + bufData.clientsLength()));
FLClientStatus status;
switch (retcode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList success"));
u3ClientList.clear();
int clientSize = bufData.clientsLength();
for (int i = 0; i < clientSize; i++) {
String curFlId = bufData.clients(i);
u3ClientList.add(curFlId);
}
try {
decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnClientList is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception {
decryptSecretsList.clear();
int size = returnShareList.size();
for (int i = 0; i < size; i++) {
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
EncryptShare encryptShare = returnShareList.get(i);
String vFlID = encryptShare.getFlID();
byte[] share = encryptShare.getShare().getArray();
byte[] iVecIn = new byte[IVEC_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC");
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
int sSize = (int) decryptShare[0];
int bSize = (int) decryptShare[1];
int sIndexLen = (int) decryptShare[2];
int bIndexLen = (int) decryptShare[3];
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen));
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize);
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize);
NewArray<byte[]> sSkVu = new NewArray<>();
sSkVu.setSize(sSize);
sSkVu.setArray(sSkUv);
NewArray bVu = new NewArray();
bVu.setSize(bSize);
bVu.setArray(bUv);
decryptShareSecrets.setFlID(vFlID);
decryptShareSecrets.setSSkVu(sSkVu);
decryptShareSecrets.setBVu(bVu);
decryptShareSecrets.setSIndex(sIndex);
decryptShareSecrets.setIndexB(bIndex);
decryptSecretsList.add(decryptShareSecrets);
}
}
}

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.math.ec.rfc7748.X25519;
import java.security.SecureRandom;
import java.util.logging.Logger;
public class KEYAgreement {
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
private static final int PBKDF2_ITERATIONS = 10000;
private static final int SALT_SIZE = 32;
private static final int HASH_BIT_SIZE = 256;
private static final int KEY_LEN = X25519.SCALAR_SIZE;
private SecureRandom random = new SecureRandom();
public byte[] generatePrivateKey() {
byte[] privateKey = new byte[KEY_LEN];
X25519.generatePrivateKey(random, privateKey);
return privateKey;
}
public byte[] generatePublicKey(byte[] privatekey) {
byte[] publicKey = new byte[KEY_LEN];
X25519.generatePublicKey(privatekey, 0, publicKey, 0);
return publicKey;
}
public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) {
byte[] secret = new byte[KEY_LEN];
X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0);
return secret;
}
public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
byte[] saltB = new byte[SALT_SIZE];
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
gen.init(password, saltB, PBKDF2_ITERATIONS);
byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
return dk;
}
}

View File

@ -0,0 +1,82 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.List;
import java.util.logging.Logger;
public class Random {
/**
* random generate RNG algorithm name
*/
private static final Logger LOGGER = Logger.getLogger(Random.class.toString());
private static final String RNG_ALGORITHM = "SHA1PRNG";
private static final int RANDOM_LEN = 128 / 8;
public void getRandomBytes(byte[] secret) {
try {
SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
secureRandom.nextBytes(secret);
} catch (NoSuchAlgorithmException e) {
e.printStackTrace();
}
}
public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception {
int intV = Integer.SIZE / 8;
int size = length * intV;
byte[] data = new byte[size];
for (int i = 0; i < size; i++) {
data[i] = 0;
}
byte[] ivec = new byte[RANDOM_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR");
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data);
for (int i = 0; i < length; i++) {
int[] sub = new int[intV];
for (int j = 0; j < 4; j++) {
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
}
int subI = byte2int(sub, 4);
Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
noise.add(f);
}
}
public static int byte2int(int[] data, int n) {
switch (n) {
case 1:
return (int) data[0];
case 2:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
case 3:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
case 4:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
| (data[3] << 24 & 0xff000000);
default:
return 0;
}
}
}

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import mindspore.schema.ClientShare;
import mindspore.schema.ResponseCode;
import java.nio.ByteBuffer;
import java.time.LocalDateTime;
import java.util.List;
import java.util.logging.Logger;
public class ReconstructSecretReq {
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
private FLCommunication flCommunication;
private String nextRequestTime;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
public String getNextRequestTime() {
return nextRequestTime;
}
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public ReconstructSecretReq() {
flCommunication = FLCommunication.getInstance();
}
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============sendReconstructSecret url: " + url + "=============="));
FlatBufferBuilder builder = new FlatBufferBuilder();
int desFlId = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
int time = builder.createString(dateTime);
int shareSecretsSize = decryptShareSecretsList.size();
if (shareSecretsSize <= 0) {
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite."));
return FLClientStatus.FAILED;
} else {
int[] decryptShareList = new int[shareSecretsSize];
for (int i = 0; i < shareSecretsSize; i++) {
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
String srcFlId = decryptShareSecrets.getFlID();
byte[] share;
int index;
if (u3ClientList.contains(srcFlId)) {
share = decryptShareSecrets.getBVu().getArray();
index = decryptShareSecrets.getIndexB();
} else {
share = decryptShareSecrets.getSSkVu().getArray();
index = decryptShareSecrets.getSIndex();
}
int fbsSrcFlId = builder.createString(srcFlId);
int fbsShare = ClientShare.createShareVector(builder, share);
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
decryptShareList[i] = clientShare;
}
int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList);
int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time);
builder.finish(reconstructSecretRoot);
byte[] msg = builder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
ByteBuffer buffer = ByteBuffer.wrap(responseData);
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer);
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp);
return status;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
e.printStackTrace();
return FLClientStatus.FAILED;
}
}
}
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
int retcode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode));
LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason()));
LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration()));
LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime()));
switch (retcode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
case (ResponseCode.SystemError):
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReconstructSecret is invalid: " + retcode));
return FLClientStatus.FAILED;
}
}
}

View File

@ -0,0 +1,136 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.mindspore.flclient.Common;
import java.math.BigInteger;
import java.util.Random;
import java.util.logging.Logger;
public class ShareSecrets {
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
public final class SecretShare {
public SecretShare(final int num, final BigInteger share) {
this.num = num;
this.share = share;
}
public int getNum() {
return num;
}
public BigInteger getShare() {
return share;
}
@Override
public String toString() {
return "SecretShare [num=" + num + ", share=" + share + "]";
}
private final int num;
private final BigInteger share;
}
public ShareSecrets(final int k, final int n) {
this.k = k;
this.n = n;
random = new Random();
}
public SecretShare[] split(final byte[] bytes, byte[] primeByte) {
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
final int modLength = secret.bitLength() + 1;
prime = BaseUtil.byteArray2BigInteger(primeByte);
final BigInteger[] coeff = new BigInteger[k - 1];
LOGGER.info(Common.addTag("Prime Number: " + prime));
for (int i = 0; i < k - 1; i++) {
coeff[i] = randomZp(prime);
LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i]));
}
final SecretShare[] shares = new SecretShare[n];
for (int i = 1; i <= n; i++) {
BigInteger accum = secret;
for (int j = 1; j < k; j++) {
final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime);
accum = accum.add(t2).mod(prime);
}
shares[i - 1] = new SecretShare(i, accum);
LOGGER.info(Common.addTag("Share " + shares[i - 1]));
}
return shares;
}
public BigInteger getPrime() {
return prime;
}
public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) {
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
BigInteger accum = BigInteger.ZERO;
for (int j = 0; j < k; j++) {
BigInteger num = BigInteger.ONE;
BigInteger den = BigInteger.ONE;
BigInteger tmp;
for (int m = 0; m < k; m++) {
if (j != m) {
num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum);
tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1));
tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum);
den = den.multiply(tmp).mod(primeNum);
}
}
final BigInteger value = shares[j].getShare();
tmp = den.modInverse(primeNum);
tmp = tmp.multiply(num).mod(primeNum);
tmp = tmp.multiply(value).mod(primeNum);
accum = accum.add(tmp).mod(primeNum);
LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum));
}
LOGGER.info(Common.addTag("The secret is: " + accum));
return accum;
}
private BigInteger randomZp(final BigInteger p) {
while (true) {
final BigInteger r = new BigInteger(p.bitLength(), random);
if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) {
return r;
}
}
}
private BigInteger prime;
private final int k;
private final int n;
private final Random random;
private final int SECRET_MAX_LEN = 32;
}

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher.struct;
public class ClientPublicKey {
private String flID;
private NewArray<byte[]> cPK;
private NewArray<byte[]> sPk;
public String getFlID() {
return flID;
}
public void setFlID(String flID) {
this.flID = flID;
}
public NewArray<byte[]> getCPK() {
return cPK;
}
public void setCPK(NewArray<byte[]> cPK) {
this.cPK = cPK;
}
public NewArray<byte[]> getSPK() {
return sPk;
}
public void setSPK(NewArray<byte[]> sPk) {
this.sPk = sPk;
}
}

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher.struct;
public class DecryptShareSecrets {
private String flID;
private NewArray<byte[]> sSkVu;
private NewArray<byte[]> bVu;
private int sIndex;
private int indexB;
public String getFlID() {
return flID;
}
public void setFlID(String flID) {
this.flID = flID;
}
public NewArray<byte[]> getSSkVu() {
return sSkVu;
}
public void setSSkVu(NewArray<byte[]> sSkVu) {
this.sSkVu = sSkVu;
}
public NewArray<byte[]> getBVu() {
return bVu;
}
public void setBVu(NewArray<byte[]> bVu) {
this.bVu = bVu;
}
public int getSIndex() {
return sIndex;
}
public void setSIndex(int sIndex) {
this.sIndex = sIndex;
}
public int getIndexB() {
return indexB;
}
public void setIndexB(int indexB) {
this.indexB = indexB;
}
}

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher.struct;
public class EncryptShare {
private String flID;
private NewArray<byte[]> share;
public String getFlID() {
return flID;
}
public void setFlID(String flID) {
this.flID = flID;
}
public NewArray<byte[]> getShare() {
return share;
}
public void setShare(NewArray<byte[]> share) {
this.share = share;
}
}

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher.struct;
public class NewArray<T> {
private int size;
private T array;
public int getSize() {
return size;
}
public void setSize(int size) {
this.size = size;
}
public T getArray() {
return array;
}
public void setArray(T array) {
this.array = array;
}
}

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher.struct;
public class ShareSecret {
private String flID;
private NewArray<byte[]> share;
private int index;
public String getFlID() {
return flID;
}
public void setFlID(String flID) {
this.flID = flID;
}
public NewArray<byte[]> getShare() {
return share;
}
public void setShare(NewArray<byte[]> share) {
this.share = share;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
}