!20104 fix bug of albert task in case PWEcrypt for master
Merge pull request !20104 from zhoushan33/flclient0712_om_master
This commit is contained in:
commit
9924825a70
|
@ -112,7 +112,12 @@ public class FLLiteClient {
|
||||||
prime[i] = (byte) cipherPublicParams.prime(i);
|
prime[i] = (byte) cipherPublicParams.prime(i);
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
||||||
|
if (minSecretNum <= 0) {
|
||||||
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
|
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
|
||||||
|
break;
|
||||||
case DP_ENCRYPT:
|
case DP_ENCRYPT:
|
||||||
dpEps = cipherPublicParams.dpEps();
|
dpEps = cipherPublicParams.dpEps();
|
||||||
dpDelta = cipherPublicParams.dpDelta();
|
dpDelta = cipherPublicParams.dpDelta();
|
||||||
|
@ -233,7 +238,7 @@ public class FLLiteClient {
|
||||||
status = FLClientStatus.SUCCESS;
|
status = FLClientStatus.SUCCESS;
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
LOGGER.info(Common.addTag("[train] train in adbert"));
|
LOGGER.info(Common.addTag("[train] train in albert"));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
|
|
|
@ -17,6 +17,8 @@ package com.mindspore.flclient;
|
||||||
|
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
|
||||||
public class FLParameter {
|
public class FLParameter {
|
||||||
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
||||||
|
|
||||||
|
@ -119,7 +121,7 @@ public class FLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getVocabFile() {
|
public String getVocabFile() {
|
||||||
if ("null".equals(vocabFile) && "adbert".equals(flName)) {
|
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
@ -137,7 +139,7 @@ public class FLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getIdsFile() {
|
public String getIdsFile() {
|
||||||
if ("null".equals(idsFile) && "adbert".equals(flName)) {
|
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,7 +96,7 @@ public class GetModel {
|
||||||
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
|
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
|
||||||
int fmCount = responseDataBuf.featureMapLength();
|
int fmCount = responseDataBuf.featureMapLength();
|
||||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
|
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
|
|
|
@ -28,6 +28,9 @@ import java.util.Map;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
public class SecureProtocol {
|
public class SecureProtocol {
|
||||||
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
|
@ -35,7 +38,7 @@ public class SecureProtocol {
|
||||||
private int iteration;
|
private int iteration;
|
||||||
private CipherClient cipher;
|
private CipherClient cipher;
|
||||||
private FLClientStatus status;
|
private FLClientStatus status;
|
||||||
private float[] featureMask;
|
private float[] featureMask = new float[0];
|
||||||
private double dpEps;
|
private double dpEps;
|
||||||
private double dpDelta;
|
private double dpDelta;
|
||||||
private double dpNormClip;
|
private double dpNormClip;
|
||||||
|
@ -126,13 +129,17 @@ public class SecureProtocol {
|
||||||
}
|
}
|
||||||
|
|
||||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||||
|
if (featureMask == null || featureMask.length == 0) {
|
||||||
|
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
|
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
|
||||||
// get feature map
|
// get feature map
|
||||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||||
if (flParameter.getFlName().equals("adbert")) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||||
} else if (flParameter.getFlName().equals("lenet")) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||||
}
|
}
|
||||||
|
@ -266,10 +273,10 @@ public class SecureProtocol {
|
||||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||||
// get feature map
|
// get feature map
|
||||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||||
if (flParameter.getFlName().equals("adbert")) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||||
} else if (flParameter.getFlName().equals("lenet")) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -365,7 +365,7 @@ public class SyncFLJob {
|
||||||
SyncFLJob syncFLJob = new SyncFLJob();
|
SyncFLJob syncFLJob = new SyncFLJob();
|
||||||
if (task.equals("train")) {
|
if (task.equals("train")) {
|
||||||
flParameter.setUseHttps(useHttps);
|
flParameter.setUseHttps(useHttps);
|
||||||
if (useHttps) {
|
if (useSSL) {
|
||||||
flParameter.setCertPath(certPath);
|
flParameter.setCertPath(certPath);
|
||||||
}
|
}
|
||||||
flParameter.setHostName(ip);
|
flParameter.setHostName(ip);
|
||||||
|
@ -396,7 +396,7 @@ public class SyncFLJob {
|
||||||
syncFLJob.modelInference();
|
syncFLJob.modelInference();
|
||||||
} else if (task.equals("getModel")) {
|
} else if (task.equals("getModel")) {
|
||||||
flParameter.setUseHttps(useHttps);
|
flParameter.setUseHttps(useHttps);
|
||||||
if (useHttps) {
|
if (useSSL) {
|
||||||
flParameter.setCertPath(certPath);
|
flParameter.setCertPath(certPath);
|
||||||
}
|
}
|
||||||
flParameter.setHostName(ip);
|
flParameter.setHostName(ip);
|
||||||
|
|
|
@ -82,20 +82,30 @@ public class UpdateModel {
|
||||||
case PW_ENCRYPT:
|
case PW_ENCRYPT:
|
||||||
try {
|
try {
|
||||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||||
|
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||||
|
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||||
return this;
|
return this;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception");
|
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
|
||||||
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
case DP_ENCRYPT:
|
case DP_ENCRYPT:
|
||||||
try {
|
try {
|
||||||
int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize);
|
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||||
|
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
|
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||||
return this;
|
return this;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
|
||||||
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
case NOT_ENCRYPT:
|
case NOT_ENCRYPT:
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in New Issue