!20695 add safeMod code for PWEcrypt in master

Merge pull request !20695 from zhoushan33/flclient0722_om_master
This commit is contained in:
i-robot 2021-07-23 08:04:27 +00:00 committed by Gitee
commit a0644cf073
11 changed files with 101 additions and 40 deletions

View File

@ -261,7 +261,6 @@ public class CipherClient {
public FLClientStatus requestExchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="));
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============requestExchangeKeys url: " + url + "=============="));
genDHKeyPairs();
byte[] cPK = cKey.get(0);
byte[] sPK = sKey.get(0);
@ -276,6 +275,12 @@ public class CipherClient {
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[requestExchangeKeys] The cluster is in safemode, need wait some time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys);
@ -313,7 +318,6 @@ public class CipherClient {
public FLClientStatus getExchangeKeys() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getExchangeKeys url: " + url + "=============="));
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
@ -323,6 +327,12 @@ public class CipherClient {
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getExchangeKeys] The cluster is in safemode, need wait some time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys);
@ -378,7 +388,6 @@ public class CipherClient {
public FLClientStatus requestShareSecrets() throws Exception {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============requestShareSecrets url: " + url + "=============="));
genIndividualSecret();
genEncryptExchangedKeys();
encryptShares();
@ -410,6 +419,12 @@ public class CipherClient {
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[requestShareSecrets] The cluster is in safemode, need wait some time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets);
@ -448,7 +463,6 @@ public class CipherClient {
public FLClientStatus getShareSecrets() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getShareSecrets url: " + url + "=============="));
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
@ -458,6 +472,12 @@ public class CipherClient {
byte[] msg = fbBuilder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getShareSecrets] The cluster is in safemode, need wait some time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets);

View File

@ -171,7 +171,6 @@ public class FLLiteClient {
public FLClientStatus startFLJob() {
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server===================================="));
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[startFLJob] ==============startFLJob url: " + url + "=============="));
StartFLJob startFLJob = StartFLJob.getInstance();
Date date = new Date();
long time = date.getTime();
@ -261,7 +260,6 @@ public class FLLiteClient {
public FLClientStatus updateModel() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[updateModel] ==============updateModel url: " + url + "=============="));
UpdateModel updateModelBuf = UpdateModel.getInstance();
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
@ -299,7 +297,6 @@ public class FLLiteClient {
public FLClientStatus getModel() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
try {
@ -502,6 +499,10 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
if (dataPath.split(",").length < 2) {
LOGGER.info(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the format of <data.bin,label.bin>"));
return -1;
}
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1]));
}

View File

@ -51,6 +51,8 @@ public class FLParameter {
private static volatile FLParameter flParameter;
private FLParameter() {}
public static FLParameter getInstance() {
FLParameter localRef = flParameter;
if (localRef == null) {
@ -290,8 +292,8 @@ public class FLParameter {
}
public int getServerNum() {
if (serverNum == 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is zero, please set it before use"));
if (serverNum <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is <= 0, it should be > 0, please set it before use"));
throw new RuntimeException();
}
return serverNum;

View File

@ -72,7 +72,7 @@ public class GetModel {
}
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
private static GetModel getModel;
private static volatile GetModel getModel;
private GetModel() {
}
@ -81,10 +81,16 @@ public class GetModel {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
public static GetModel getInstance() {
if (getModel == null) {
getModel = new GetModel();
GetModel localRef = getModel;
if (localRef == null) {
synchronized (GetModel.class) {
localRef = getModel;
if (localRef == null) {
getModel = localRef = new GetModel();
}
}
}
return getModel;
return localRef;
}
public byte[] getRequestGetModel(String name, int iteration) {

View File

@ -42,7 +42,7 @@ public class LocalFLParameter {
Common.setAlbertWeightName(albertWeightName);
}
public static synchronized LocalFLParameter getInstance() {
public static LocalFLParameter getInstance() {
LocalFLParameter localRef = localFLParameter;
if (localRef == null) {
synchronized (LocalFLParameter.class) {

View File

@ -33,13 +33,13 @@ import java.security.cert.X509Certificate;
import java.util.logging.Logger;
public class SSLSocketFactoryTools {
private static final Logger logger = Logger.getLogger(SSLSocketFactory.class.toString());
private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private X509Certificate x509Certificate;
private SSLSocketFactory sslSocketFactory;
private SSLContext sslContext;
private MyTrustManager myTrustManager;
private static SSLSocketFactoryTools instance;
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
private SSLSocketFactoryTools() {
initSslSocketFactory();
@ -56,15 +56,21 @@ public class SSLSocketFactoryTools {
sslSocketFactory = sslContext.getSocketFactory();
} catch (Exception e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " + e.getMessage()));
}
}
public static SSLSocketFactoryTools getInstance() {
if (instance == null) {
instance = new SSLSocketFactoryTools();
SSLSocketFactoryTools localRef = sslSocketFactoryTools;
if (localRef == null) {
synchronized (SSLSocketFactoryTools.class) {
localRef = sslSocketFactoryTools;
if (localRef == null) {
sslSocketFactoryTools = localRef = new SSLSocketFactoryTools();
}
}
}
return instance;
return localRef;
}
public X509Certificate readCert(String assetName) {
@ -72,7 +78,7 @@ public class SSLSocketFactoryTools {
try {
inputStream = new FileInputStream(assetName);
} catch (Exception e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of read inputStream in readCert: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of read inputStream in readCert: " + e.getMessage()));
return null;
}
X509Certificate cert = null;
@ -80,7 +86,7 @@ public class SSLSocketFactoryTools {
CertificateFactory cf = CertificateFactory.getInstance("X.509");
cert = (X509Certificate) cf.generateCertificate(inputStream);
} catch (Exception e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of creating CertificateFactory in readCert: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of creating CertificateFactory in readCert: " + e.getMessage()));
} finally {
try {
if (inputStream != null) {
@ -128,15 +134,19 @@ public class SSLSocketFactoryTools {
try {
cert.verify(((X509Certificate) this.cert).getPublicKey());
} catch (NoSuchAlgorithmException e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
} catch (InvalidKeyException e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch InvalidKeyException in checkServerTrusted: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch InvalidKeyException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
} catch (NoSuchProviderException e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
} catch (SignatureException e) {
logger.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
}
logger.info(Common.addTag("checkServerTrusted success!"));
LOGGER.info(Common.addTag("checkServerTrusted success!"));
}
}
@ -149,8 +159,8 @@ public class SSLSocketFactoryTools {
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
logger.info(Common.addTag("[SSLSocketFactoryTools] server hostname: " + flParameter.getHostName()));
logger.info(Common.addTag("[SSLSocketFactoryTools] client request hostname: " + hostname));
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] server hostname: " + flParameter.getHostName()));
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] client request hostname: " + hostname));
return hostname.equals(flParameter.getHostName());
}
};

View File

@ -86,7 +86,7 @@ public class StartFLJob {
}
}
private static StartFLJob startFLJob;
private static volatile StartFLJob startFLJob;
private FLClientStatus status;
@ -101,10 +101,16 @@ public class StartFLJob {
}
public static StartFLJob getInstance() {
if (startFLJob == null) {
startFLJob = new StartFLJob();
StartFLJob localRef = startFLJob;
if (localRef == null) {
synchronized (StartFLJob.class) {
localRef = startFLJob;
if (localRef == null) {
startFLJob = localRef = new StartFLJob();
}
}
}
return startFLJob;
return localRef;
}
public String getNextRequestTime() {
@ -218,7 +224,7 @@ public class StartFLJob {
}
public FLClientStatus doResponse(ResponseFLJob flJob) {
LOGGER.info(Common.addTag("[startFLJob] return code: " + flJob.retcode()));
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
LOGGER.info(Common.addTag("[startFLJob] is selected: " + flJob.isSelected()));
@ -246,7 +252,7 @@ public class StartFLJob {
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retCode));
LOGGER.severe(Common.addTag("[startFLJob] the return <retCode> from server is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}

View File

@ -264,13 +264,16 @@ public class SyncFLJob {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
return FLClientStatus.FAILED;
}
flParameter.setUseSSL(flParameter.isUseSSL());
FLCommunication flCommunication = FLCommunication.getInstance();
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
status = FLClientStatus.WAIT;
return status;
}
LOGGER.info(Common.addTag("[getModel] get model request success"));
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);

View File

@ -170,7 +170,7 @@ public class UpdateModel {
private UpdateModel() {
}
public static synchronized UpdateModel getInstance() {
public static UpdateModel getInstance() {
UpdateModel localRef = updateModel;
if (localRef == null) {
synchronized (UpdateModel.class) {

View File

@ -36,6 +36,7 @@ import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
public class ClientListReq {
@ -65,7 +66,6 @@ public class ClientListReq {
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "=============="));
FlatBufferBuilder builder = new FlatBufferBuilder();
int id = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
@ -75,6 +75,12 @@ public class ClientListReq {
byte[] msg = builder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getClientList] The cluster is in safemode, need wait some time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);

View File

@ -31,6 +31,8 @@ import java.time.LocalDateTime;
import java.util.List;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
public class ReconstructSecretReq {
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
private FLCommunication flCommunication;
@ -57,7 +59,6 @@ public class ReconstructSecretReq {
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[PairWiseMask] ==============sendReconstructSecret url: " + url + "=============="));
FlatBufferBuilder builder = new FlatBufferBuilder();
int desFlId = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
@ -91,6 +92,12 @@ public class ReconstructSecretReq {
byte[] msg = builder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[sendReconstructSecret] The cluster is in safemode, need wait some time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer);
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp);