From b9c028ed24b9a524bd751357d415e0f40dbd4e4a Mon Sep 17 00:00:00 2001 From: zhoushan Date: Mon, 12 Jul 2021 20:49:42 +0800 Subject: [PATCH] fix bug of albert task in case PWEcrypt for master --- .../com/mindspore/flclient/FLLiteClient.java | 7 ++++++- .../com/mindspore/flclient/FLParameter.java | 6 ++++-- .../java/com/mindspore/flclient/GetModel.java | 2 +- .../com/mindspore/flclient/SecureProtocol.java | 17 ++++++++++++----- .../java/com/mindspore/flclient/SyncFLJob.java | 4 ++-- .../com/mindspore/flclient/UpdateModel.java | 16 +++++++++++++--- 6 files changed, 38 insertions(+), 14 deletions(-) diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index edd239158a7..adc5616f2aa 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -112,7 +112,12 @@ public class FLLiteClient { prime[i] = (byte) cipherPublicParams.prime(i); } LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + minSecretNum)); + if (minSecretNum <= 0) { + LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server is not valid: <=0")); + return -1; + } LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime))); + break; case DP_ENCRYPT: dpEps = cipherPublicParams.dpEps(); dpDelta = cipherPublicParams.dpDelta(); @@ -233,7 +238,7 @@ public class FLLiteClient { status = FLClientStatus.SUCCESS; retCode = ResponseCode.SUCCEED; if (flParameter.getFlName().equals(ALBERT)) { - LOGGER.info(Common.addTag("[train] train in adbert")); + LOGGER.info(Common.addTag("[train] train in albert")); AlTrainBert alTrainBert = AlTrainBert.getInstance(); int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); if (tag == -1) { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java index 7acc9355e5e..771d6f17bc4 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java @@ -17,6 +17,8 @@ package com.mindspore.flclient; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; + public class FLParameter { private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString()); @@ -119,7 +121,7 @@ public class FLParameter { } public String getVocabFile() { - if ("null".equals(vocabFile) && "adbert".equals(flName)) { + if ("null".equals(vocabFile) && ALBERT.equals(flName)) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); throw new RuntimeException(); } @@ -137,7 +139,7 @@ public class FLParameter { } public String getIdsFile() { - if ("null".equals(idsFile) && "adbert".equals(flName)) { + if ("null".equals(idsFile) && ALBERT.equals(flName)) { LOGGER.severe(Common.addTag("[flParameter] the parameter of is null, please set it before use")); throw new RuntimeException(); } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java index e8079f02560..5352c08cd1b 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java @@ -96,7 +96,7 @@ public class GetModel { private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) { int fmCount = responseDataBuf.featureMapLength(); if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[getModel] into ")); + LOGGER.info(Common.addTag("[getModel] into ")); ArrayList albertFeatureMaps = new ArrayList(); ArrayList inferFeatureMaps = new ArrayList(); for (int i = 0; i < fmCount; i++) { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java index 6f44180a6c4..d92e6743379 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java @@ -28,6 +28,9 @@ import java.util.Map; import java.util.Random; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + public class SecureProtocol { private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString()); private FLParameter flParameter = FLParameter.getInstance(); @@ -35,7 +38,7 @@ public class SecureProtocol { private int iteration; private CipherClient cipher; private FLClientStatus status; - private float[] featureMask; + private float[] featureMask = new float[0]; private double dpEps; private double dpDelta; private double dpNormClip; @@ -126,13 +129,17 @@ public class SecureProtocol { } public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) { + if (featureMask == null || featureMask.length == 0) { + LOGGER.severe("[Encrypt] feature mask is null, please check"); + return new int[0]; + } LOGGER.info("[Encrypt] feature mask size: " + featureMask.length); // get feature map Map map = new HashMap(); - if (flParameter.getFlName().equals("adbert")) { + if (flParameter.getFlName().equals(ALBERT)) { AlTrainBert alTrainBert = AlTrainBert.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - } else if (flParameter.getFlName().equals("lenet")) { + } else if (flParameter.getFlName().equals(LENET)) { TrainLenet trainLenet = TrainLenet.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); } @@ -266,10 +273,10 @@ public class SecureProtocol { public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) { // get feature map Map map = new HashMap(); - if (flParameter.getFlName().equals("adbert")) { + if (flParameter.getFlName().equals(ALBERT)) { AlTrainBert alTrainBert = AlTrainBert.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); - } else if (flParameter.getFlName().equals("lenet")) { + } else if (flParameter.getFlName().equals(LENET)) { TrainLenet trainLenet = TrainLenet.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java index 122b7c43b1e..283a4fdb220 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -365,7 +365,7 @@ public class SyncFLJob { SyncFLJob syncFLJob = new SyncFLJob(); if (task.equals("train")) { flParameter.setUseHttps(useHttps); - if (useHttps) { + if (useSSL) { flParameter.setCertPath(certPath); } flParameter.setHostName(ip); @@ -396,7 +396,7 @@ public class SyncFLJob { syncFLJob.modelInference(); } else if (task.equals("getModel")) { flParameter.setUseHttps(useHttps); - if (useHttps) { + if (useSSL) { flParameter.setCertPath(certPath); } flParameter.setHostName(ip); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java index 4fe4a1de2a7..1c2e79dc772 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java @@ -82,20 +82,30 @@ public class UpdateModel { case PW_ENCRYPT: try { int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize); + if (fmOffsetsPW == null || fmOffsetsPW.length == 0) { + LOGGER.severe("[Encrypt] the return fmOffsetsPW from is null, please check"); + throw new RuntimeException(); + } this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!")); return this; } catch (Exception e) { - LOGGER.severe("[Encrypt] catch error in maskModel: catch Exception"); + LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage()); + throw new RuntimeException(); } case DP_ENCRYPT: try { - int[] fmOffsetsPW = secureProtocol.dpMaskModel(builder, trainDataSize); - this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW); + int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize); + if (fmOffsetsDP == null || fmOffsetsDP.length == 0) { + LOGGER.severe("[Encrypt] the return fmOffsetsDP from is null, please check"); + throw new RuntimeException(); + } + this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP); LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!")); return this; } catch (Exception e) { LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage())); + throw new RuntimeException(); } case NOT_ENCRYPT: default: