!19057 Add adaptive dp_norm_clip for federated learning's DP mechanism
Merge pull request !19057 from jxlang910/master
This commit is contained in:
commit
d330eaaba4
|
@ -57,7 +57,8 @@ public class FLLiteClient {
|
|||
private int trainDataSize;
|
||||
private double dpEps = 100;
|
||||
private double dpDelta = 0.01;
|
||||
private double dpNormClip = 2.0;
|
||||
public double dpNormClipFactor = 1.0;
|
||||
public double dpNormClipAdapt = 0.5;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
@ -110,10 +111,10 @@ public class FLLiteClient {
|
|||
case DP_ENCRYPT:
|
||||
dpEps = cipherPublicParams.dpEps();
|
||||
dpDelta = cipherPublicParams.dpDelta();
|
||||
dpNormClip = cipherPublicParams.dpNormClip();
|
||||
dpNormClipFactor = cipherPublicParams.dpNormClip();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClip> from server: " + dpNormClip));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + dpNormClipFactor));
|
||||
break;
|
||||
default:
|
||||
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
|
||||
|
@ -373,7 +374,7 @@ public class FLLiteClient {
|
|||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
Map<String, float[]> copyMap = getOldMapCopy(map);
|
||||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClip, copyMap);
|
||||
curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, copyMap);
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (curStatus != FLClientStatus.SUCCESS) {
|
||||
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
|
||||
|
|
|
@ -24,6 +24,8 @@ import mindspore.schema.ResponseGetModel;
|
|||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
|
@ -34,6 +36,7 @@ public class SyncFLJob {
|
|||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
|
||||
private Map<String, float[]> oldFeatureMap;
|
||||
|
||||
public SyncFLJob() {
|
||||
}
|
||||
|
@ -72,6 +75,9 @@ public class SyncFLJob {
|
|||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration()));
|
||||
|
||||
// get the feature map before train
|
||||
getOldFeatureMap(client);
|
||||
|
||||
// create mask
|
||||
curStatus = client.getFeatureMask();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
|
@ -130,7 +136,10 @@ public class SyncFLJob {
|
|||
}
|
||||
LOGGER.info(Common.addTag("[getModel] getModel succeed"));
|
||||
|
||||
//evaluate model after getting model from server
|
||||
// get the feature map after averaging and update dp_norm_clip
|
||||
updateDpNormClip(client);
|
||||
|
||||
// evaluate model after getting model from server
|
||||
if (flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
|
||||
} else {
|
||||
|
@ -149,6 +158,49 @@ public class SyncFLJob {
|
|||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
||||
}
|
||||
|
||||
private void updateDpNormClip(FLLiteClient client) {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
Map<String, float[]> fedFeatureMap = getFeatureMap();
|
||||
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
|
||||
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
|
||||
client.dpNormClipAdapt = client.dpNormClipFactor*fedWeightUpdateNorm;
|
||||
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt));
|
||||
}
|
||||
}
|
||||
private void getOldFeatureMap(FLLiteClient client) {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
Map<String, float[]> featureMap = getFeatureMap();
|
||||
oldFeatureMap = client.getOldMapCopy(featureMap);
|
||||
}
|
||||
}
|
||||
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData){
|
||||
float updateL2Norm = 0;
|
||||
for (String key: originalData.keySet()) {
|
||||
float[] data=originalData.get(key);
|
||||
float[] dataAfterUpdate = newData.get(key);
|
||||
for (int j = 0; j<data.length; j++) {
|
||||
float updateData = data[j] - dataAfterUpdate[j];
|
||||
updateL2Norm += updateData*updateData;
|
||||
}
|
||||
}
|
||||
updateL2Norm = (float)Math.sqrt(updateL2Norm);
|
||||
return updateL2Norm;
|
||||
}
|
||||
|
||||
private Map<String, float[]> getFeatureMap() {
|
||||
Map<String, float[]> featureMap = new HashMap<>();
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
}
|
||||
return featureMap;
|
||||
}
|
||||
|
||||
public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) {
|
||||
int[] labels = new int[0];
|
||||
if (flName.equals(ADBERT)) {
|
||||
|
|
Loading…
Reference in New Issue