fix bug of albert task in case PWEcrypt for master

This commit is contained in:
zhoushan 2021-07-12 20:49:42 +08:00
parent 6122ee8202
commit b9c028ed24
6 changed files with 38 additions and 14 deletions

View File

@ -112,7 +112,12 @@ public class FLLiteClient {
prime[i] = (byte) cipherPublicParams.prime(i);
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
if (minSecretNum <= 0) {
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
return -1;
}
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
break;
case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps();
dpDelta = cipherPublicParams.dpDelta();
@ -233,7 +238,7 @@ public class FLLiteClient {
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[train] train in adbert"));
LOGGER.info(Common.addTag("[train] train in albert"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {

View File

@ -17,6 +17,8 @@ package com.mindspore.flclient;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
public class FLParameter {
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
@ -119,7 +121,7 @@ public class FLParameter {
}
public String getVocabFile() {
if ("null".equals(vocabFile) && "adbert".equals(flName)) {
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
throw new RuntimeException();
}
@ -137,7 +139,7 @@ public class FLParameter {
}
public String getIdsFile() {
if ("null".equals(idsFile) && "adbert".equals(flName)) {
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
throw new RuntimeException();
}

View File

@ -96,7 +96,7 @@ public class GetModel {
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
int fmCount = responseDataBuf.featureMapLength();
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {

View File

@ -28,6 +28,9 @@ import java.util.Map;
import java.util.Random;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class SecureProtocol {
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
@ -35,7 +38,7 @@ public class SecureProtocol {
private int iteration;
private CipherClient cipher;
private FLClientStatus status;
private float[] featureMask;
private float[] featureMask = new float[0];
private double dpEps;
private double dpDelta;
private double dpNormClip;
@ -126,13 +129,17 @@ public class SecureProtocol {
}
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
if (featureMask == null || featureMask.length == 0) {
LOGGER.severe("[Encrypt] feature mask is null, please check");
return new int[0];
}
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) {
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) {
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
@ -266,10 +273,10 @@ public class SecureProtocol {
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) {
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) {
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}

View File

@ -365,7 +365,7 @@ public class SyncFLJob {
SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) {
flParameter.setUseHttps(useHttps);
if (useHttps) {
if (useSSL) {
flParameter.setCertPath(certPath);
}
flParameter.setHostName(ip);
@ -396,7 +396,7 @@ public class SyncFLJob {
syncFLJob.modelInference();
} else if (task.equals("getModel")) {
flParameter.setUseHttps(useHttps);
if (useHttps) {
if (useSSL) {
flParameter.setCertPath(certPath);
}
flParameter.setHostName(ip);

View File

@ -82,20 +82,30 @@ public class UpdateModel {
case PW_ENCRYPT:
try {
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception");
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
throw new RuntimeException();
}
case DP_ENCRYPT:
try {
int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize);
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
throw new RuntimeException();
}
case NOT_ENCRYPT:
default: