!20104 fix bug of albert task in case PWEcrypt for master

Merge pull request !20104 from zhoushan33/flclient0712_om_master
This commit is contained in:
i-robot 2021-07-13 11:56:00 +00:00 committed by Gitee
commit 9924825a70
6 changed files with 38 additions and 14 deletions

View File

@ -112,7 +112,12 @@ public class FLLiteClient {
prime[i] = (byte) cipherPublicParams.prime(i); prime[i] = (byte) cipherPublicParams.prime(i);
} }
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
if (minSecretNum <= 0) {
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
return -1;
}
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime))); LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
break;
case DP_ENCRYPT: case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps(); dpEps = cipherPublicParams.dpEps();
dpDelta = cipherPublicParams.dpDelta(); dpDelta = cipherPublicParams.dpDelta();
@ -233,7 +238,7 @@ public class FLLiteClient {
status = FLClientStatus.SUCCESS; status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[train] train in adbert")); LOGGER.info(Common.addTag("[train] train in albert"));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) { if (tag == -1) {

View File

@ -17,6 +17,8 @@ package com.mindspore.flclient;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
public class FLParameter { public class FLParameter {
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString()); private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
@ -119,7 +121,7 @@ public class FLParameter {
} }
public String getVocabFile() { public String getVocabFile() {
if ("null".equals(vocabFile) && "adbert".equals(flName)) { if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
throw new RuntimeException(); throw new RuntimeException();
} }
@ -137,7 +139,7 @@ public class FLParameter {
} }
public String getIdsFile() { public String getIdsFile() {
if ("null".equals(idsFile) && "adbert".equals(flName)) { if ("null".equals(idsFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
throw new RuntimeException(); throw new RuntimeException();
} }

View File

@ -96,7 +96,7 @@ public class GetModel {
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) { private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
int fmCount = responseDataBuf.featureMapLength(); int fmCount = responseDataBuf.featureMapLength();
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>")); LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {

View File

@ -28,6 +28,9 @@ import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class SecureProtocol { public class SecureProtocol {
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString()); private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
@ -35,7 +38,7 @@ public class SecureProtocol {
private int iteration; private int iteration;
private CipherClient cipher; private CipherClient cipher;
private FLClientStatus status; private FLClientStatus status;
private float[] featureMask; private float[] featureMask = new float[0];
private double dpEps; private double dpEps;
private double dpDelta; private double dpDelta;
private double dpNormClip; private double dpNormClip;
@ -126,13 +129,17 @@ public class SecureProtocol {
} }
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) { public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
if (featureMask == null || featureMask.length == 0) {
LOGGER.severe("[Encrypt] feature mask is null, please check");
return new int[0];
}
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length); LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
// get feature map // get feature map
Map<String, float[]> map = new HashMap<String, float[]>(); Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) { if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) { } else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} }
@ -266,10 +273,10 @@ public class SecureProtocol {
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) { public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
// get feature map // get feature map
Map<String, float[]> map = new HashMap<String, float[]>(); Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) { if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) { } else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} }

View File

@ -365,7 +365,7 @@ public class SyncFLJob {
SyncFLJob syncFLJob = new SyncFLJob(); SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) { if (task.equals("train")) {
flParameter.setUseHttps(useHttps); flParameter.setUseHttps(useHttps);
if (useHttps) { if (useSSL) {
flParameter.setCertPath(certPath); flParameter.setCertPath(certPath);
} }
flParameter.setHostName(ip); flParameter.setHostName(ip);
@ -396,7 +396,7 @@ public class SyncFLJob {
syncFLJob.modelInference(); syncFLJob.modelInference();
} else if (task.equals("getModel")) { } else if (task.equals("getModel")) {
flParameter.setUseHttps(useHttps); flParameter.setUseHttps(useHttps);
if (useHttps) { if (useSSL) {
flParameter.setCertPath(certPath); flParameter.setCertPath(certPath);
} }
flParameter.setHostName(ip); flParameter.setHostName(ip);

View File

@ -82,20 +82,30 @@ public class UpdateModel {
case PW_ENCRYPT: case PW_ENCRYPT:
try { try {
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize); int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!")); LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this; return this;
} catch (Exception e) { } catch (Exception e) {
LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception"); LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
throw new RuntimeException();
} }
case DP_ENCRYPT: case DP_ENCRYPT:
try { try {
int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize); int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!")); LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this; return this;
} catch (Exception e) { } catch (Exception e) {
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage())); LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
throw new RuntimeException();
} }
case NOT_ENCRYPT: case NOT_ENCRYPT:
default: default: