!19057 Add adaptive dp_norm_clip for federated learning's DP mechanism

Merge pull request !19057 from jxlang910/master
i-robot 2021-06-29 07:30:07 +00:00 committed by Gitee
commit d330eaaba4
2 changed files with 58 additions and 5 deletions
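In outline: the server-supplied dpNormClip is no longer used directly as the clipping bound. It is stored as a multiplicative factor (dpNormClipFactor), and the bound actually handed to the DP mechanism (dpNormClipAdapt) is recomputed each round from the L2-norm of the averaged weight update. A minimal sketch of the rule using the diff's names (the standalone method below is illustrative, not part of the commit):

// Illustrative only: the adaptive clipping rule this PR implements.
// dpNormClipFactor comes from the server; wOld are the local weights
// before training, wFed the federated weights after aggregation.
static double adaptiveNormClip(double dpNormClipFactor, float[] wOld, float[] wFed) {
    double sum = 0;
    for (int i = 0; i < wOld.length; i++) {
        double d = wOld[i] - wFed[i];
        sum += d * d;                           // accumulate squared update
    }
    return dpNormClipFactor * Math.sqrt(sum);   // factor * ||wOld - wFed||_2
}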

FLLiteClient.java

@@ -57,7 +57,8 @@ public class FLLiteClient {
private int trainDataSize;
private double dpEps = 100;
private double dpDelta = 0.01;
- private double dpNormClip = 2.0;
+ public double dpNormClipFactor = 1.0;
+ public double dpNormClipAdapt = 0.5;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
@@ -110,10 +111,10 @@
case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps();
dpDelta = cipherPublicParams.dpDelta();
- dpNormClip = cipherPublicParams.dpNormClip();
+ dpNormClipFactor = cipherPublicParams.dpNormClip();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
- LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClip> from server: " + dpNormClip));
+ LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + dpNormClipFactor));
break;
default:
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
@@ -373,7 +374,7 @@
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
}
Map<String, float[]> copyMap = getOldMapCopy(map);
- curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClip, copyMap);
+ curStatus = secureProtocol.setDPParameter(iteration, dpEps, dpDelta, dpNormClipAdapt, copyMap);
retCode = ResponseCode.SUCCEED;
if (curStatus != FLClientStatus.SUCCESS) {
LOGGER.info(Common.addTag("---Differential privacy init failed---"));
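Net effect in FLLiteClient: dpNormClipFactor holds the server value, and setDPParameter now receives the per-round dpNormClipAdapt. A rough sketch of the resulting round flow as driven from SyncFLJob below (method names follow the diff; the driver itself is hypothetical):

// Hypothetical driver, condensing the flow SyncFLJob implements below.
void runRound(FLLiteClient client) {
    getOldFeatureMap(client);      // snapshot weights before local training
    client.getFeatureMask();       // mask/DP init uses dpNormClipAdapt
    // ... local training, upload, getModel ...
    updateDpNormClip(client);      // dpNormClipAdapt = factor * ||update||_2
}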

SyncFLJob.java

@@ -24,6 +24,8 @@ import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.logging.Logger;
+ import java.util.Map;
+ import java.util.HashMap;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
@@ -34,6 +36,7 @@ public class SyncFLJob {
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
+ private Map<String, float[]> oldFeatureMap;
public SyncFLJob() {
}
@@ -72,6 +75,9 @@
}
LOGGER.info(Common.addTag("[startFLJob] startFLJob succeed, curIteration: " + client.getIteration()));
+ // get the feature map before train
+ getOldFeatureMap(client);
// create mask
curStatus = client.getFeatureMask();
if (curStatus == FLClientStatus.RESTART) {
@@ -130,7 +136,10 @@
}
LOGGER.info(Common.addTag("[getModel] getModel succeed"));
- //evaluate model after getting model from server
+ // get the feature map after averaging and update dp_norm_clip
+ updateDpNormClip(client);
+ // evaluate model after getting model from server
if (flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
} else {
@@ -149,6 +158,49 @@
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
}
+ private void updateDpNormClip(FLLiteClient client) {
+     EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
+     if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
+         Map<String, float[]> fedFeatureMap = getFeatureMap();
+         float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
+         LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
+         client.dpNormClipAdapt = client.dpNormClipFactor * fedWeightUpdateNorm;
+         LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt));
+     }
+ }
+
+ private void getOldFeatureMap(FLLiteClient client) {
+     EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
+     if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
+         Map<String, float[]> featureMap = getFeatureMap();
+         oldFeatureMap = client.getOldMapCopy(featureMap);
+     }
+ }
+
+ private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
+     float updateL2Norm = 0;
+     for (String key : originalData.keySet()) {
+         float[] data = originalData.get(key);
+         float[] dataAfterUpdate = newData.get(key);
+         for (int j = 0; j < data.length; j++) {
+             float updateData = data[j] - dataAfterUpdate[j];
+             updateL2Norm += updateData * updateData;
+         }
+     }
+     updateL2Norm = (float) Math.sqrt(updateL2Norm);
+     return updateL2Norm;
+ }
+
+ private Map<String, float[]> getFeatureMap() {
+     Map<String, float[]> featureMap = new HashMap<>();
+     if (flParameter.getFlName().equals(ADBERT)) {
+         AdTrainBert adTrainBert = AdTrainBert.getInstance();
+         featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
+     } else if (flParameter.getFlName().equals(LENET)) {
+         TrainLenet trainLenet = TrainLenet.getInstance();
+         featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
+     }
+     return featureMap;
+ }
public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) {
int[] labels = new int[0];
if (flName.equals(ADBERT)) {
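A quick numeric check of calWeightUpdateNorm (a hypothetical standalone harness, not part of the commit): two 2-element layers whose entries each change by 1.0 give a sum of squared differences of 4 and hence an L2-norm of 2.0, so that round's dpNormClipAdapt would be dpNormClipFactor * 2.0.

import java.util.HashMap;
import java.util.Map;

// Hypothetical harness: reproduces the norm calWeightUpdateNorm computes
// for a toy pair of feature maps.
public class NormCheck {
    public static void main(String[] args) {
        Map<String, float[]> before = new HashMap<>();
        Map<String, float[]> after = new HashMap<>();
        before.put("conv1", new float[]{1f, 2f});
        after.put("conv1", new float[]{0f, 1f});  // each entry moves by 1.0
        before.put("fc1", new float[]{3f, 4f});
        after.put("fc1", new float[]{2f, 3f});
        float sum = 0f;
        for (String key : before.keySet()) {
            float[] a = before.get(key);
            float[] b = after.get(key);
            for (int j = 0; j < a.length; j++) {
                float d = a[j] - b[j];
                sum += d * d;                     // four squared diffs of 1.0
            }
        }
        System.out.println(Math.sqrt(sum));       // prints 2.0
    }
}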