!20104 fix bug of albert task in case PWEcrypt for master
Merge pull request !20104 from zhoushan33/flclient0712_om_master
This commit is contained in:
commit
9924825a70
|
@ -112,7 +112,12 @@ public class FLLiteClient {
|
|||
prime[i] = (byte) cipherPublicParams.prime(i);
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
||||
if (minSecretNum <= 0) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
|
||||
return -1;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
|
||||
break;
|
||||
case DP_ENCRYPT:
|
||||
dpEps = cipherPublicParams.dpEps();
|
||||
dpDelta = cipherPublicParams.dpDelta();
|
||||
|
@ -233,7 +238,7 @@ public class FLLiteClient {
|
|||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[train] train in adbert"));
|
||||
LOGGER.info(Common.addTag("[train] train in albert"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
|
|
|
@ -17,6 +17,8 @@ package com.mindspore.flclient;
|
|||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
|
||||
public class FLParameter {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
||||
|
||||
|
@ -119,7 +121,7 @@ public class FLParameter {
|
|||
}
|
||||
|
||||
public String getVocabFile() {
|
||||
if ("null".equals(vocabFile) && "adbert".equals(flName)) {
|
||||
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
@ -137,7 +139,7 @@ public class FLParameter {
|
|||
}
|
||||
|
||||
public String getIdsFile() {
|
||||
if ("null".equals(idsFile) && "adbert".equals(flName)) {
|
||||
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
|
|
@ -96,7 +96,7 @@ public class GetModel {
|
|||
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
|
|
|
@ -28,6 +28,9 @@ import java.util.Map;
|
|||
import java.util.Random;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class SecureProtocol {
|
||||
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
|
@ -35,7 +38,7 @@ public class SecureProtocol {
|
|||
private int iteration;
|
||||
private CipherClient cipher;
|
||||
private FLClientStatus status;
|
||||
private float[] featureMask;
|
||||
private float[] featureMask = new float[0];
|
||||
private double dpEps;
|
||||
private double dpDelta;
|
||||
private double dpNormClip;
|
||||
|
@ -126,13 +129,17 @@ public class SecureProtocol {
|
|||
}
|
||||
|
||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
if (featureMask == null || featureMask.length == 0) {
|
||||
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
||||
return new int[0];
|
||||
}
|
||||
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals("adbert")) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals("lenet")) {
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
|
@ -266,10 +273,10 @@ public class SecureProtocol {
|
|||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals("adbert")) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals("lenet")) {
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
|
|
|
@ -365,7 +365,7 @@ public class SyncFLJob {
|
|||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
if (task.equals("train")) {
|
||||
flParameter.setUseHttps(useHttps);
|
||||
if (useHttps) {
|
||||
if (useSSL) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setHostName(ip);
|
||||
|
@ -396,7 +396,7 @@ public class SyncFLJob {
|
|||
syncFLJob.modelInference();
|
||||
} else if (task.equals("getModel")) {
|
||||
flParameter.setUseHttps(useHttps);
|
||||
if (useHttps) {
|
||||
if (useSSL) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setHostName(ip);
|
||||
|
|
|
@ -82,20 +82,30 @@ public class UpdateModel {
|
|||
case PW_ENCRYPT:
|
||||
try {
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
|
||||
throw new RuntimeException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||
return this;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception");
|
||||
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
|
||||
throw new RuntimeException();
|
||||
}
|
||||
case DP_ENCRYPT:
|
||||
try {
|
||||
int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
|
||||
throw new RuntimeException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||
return this;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
|
|
Loading…
Reference in New Issue