From 0c985c070373dc55fcaba1fe788f198c5d1e4e92 Mon Sep 17 00:00:00 2001 From: zhoushan Date: Wed, 7 Jul 2021 16:11:01 +0800 Subject: [PATCH] add code of albert for fl_client --- .../java/com/mindspore/flclient/Common.java | 2 +- .../com/mindspore/flclient/FLLiteClient.java | 64 +++++----- .../java/com/mindspore/flclient/GetModel.java | 94 +++++++++------ .../mindspore/flclient/LocalFLParameter.java | 2 +- .../mindspore/flclient/SecureProtocol.java | 10 +- .../com/mindspore/flclient/StartFLJob.java | 111 +++++++++++------- .../com/mindspore/flclient/SyncFLJob.java | 42 +++---- .../com/mindspore/flclient/UpdateModel.java | 10 +- .../model/{AdBert.java => AlBert.java} | 4 +- .../{AdInferBert.java => AlInferBert.java} | 16 +-- .../{AdTrainBert.java => AlTrainBert.java} | 16 +-- .../flclient/model/CustomTokenizer.java | 2 +- 12 files changed, 212 insertions(+), 161 deletions(-) rename mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/{AdBert.java => AlBert.java} (97%) rename mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/{AdInferBert.java => AlInferBert.java} (89%) rename mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/{AdTrainBert.java => AlTrainBert.java} (75%) diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java index 2ce24c8ce92..25a0a7652fb 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java @@ -28,7 +28,7 @@ import java.util.regex.Pattern; public class Common { public static final String LOG_TITLE = " "; private static final Logger LOGGER = Logger.getLogger(Common.class.toString()); - private static List flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "adbert")); + private static List flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert")); public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) { if (useHttps) { diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index f7291ebd2ba..71b417b18a1 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -17,8 +17,8 @@ package com.mindspore.flclient; import com.mindspore.flclient.cipher.BaseUtil; -import com.mindspore.flclient.model.AdInferBert; -import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.AlInferBert; +import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.CipherPublicParams; @@ -37,7 +37,7 @@ import java.util.TreeMap; import java.util.logging.Logger; import static com.mindspore.flclient.FLParameter.SLEEP_TIME; -import static com.mindspore.flclient.LocalFLParameter.ADBERT; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; import static com.mindspore.flclient.LocalFLParameter.LENET; public class FLLiteClient { @@ -83,9 +83,9 @@ public class FLLiteClient { localFLParameter.setServerMod(serverMod); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters from server: " + serverMod)); if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[startFLJob] set for AdTrainBert: " + batchSize)); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - adTrainBert.setBatchSize(batchSize); + LOGGER.info(Common.addTag("[startFLJob] set for AlTrainBert: " + batchSize)); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + alTrainBert.setBatchSize(batchSize); } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { LOGGER.info(Common.addTag("[startFLJob] set for TrainLenet: " + batchSize)); TrainLenet trainLenet = TrainLenet.getInstance(); @@ -227,12 +227,12 @@ public class FLLiteClient { LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "====================================")); status = FLClientStatus.SUCCESS; retCode = ResponseCode.SUCCEED; - if (flParameter.getFlName().equals(ADBERT)) { + if (flParameter.getFlName().equals(ALBERT)) { LOGGER.info(Common.addTag("[train] train in adbert")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - int tag = adTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs); if (tag == -1) { - LOGGER.severe(Common.addTag("[train] unsolved error code in ")); + LOGGER.severe(Common.addTag("[train] unsolved error code in ")); status = FLClientStatus.FAILED; retCode = ResponseCode.RequestError; } @@ -366,9 +366,9 @@ public class FLLiteClient { return curStatus; case DP_ENCRYPT: Map map = new HashMap(); - if (flParameter.getFlName().equals(ADBERT)) { - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + if (flParameter.getFlName().equals(ALBERT)) { + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); } else if (flParameter.getFlName().equals(LENET)) { TrainLenet trainLenet = TrainLenet.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); @@ -424,18 +424,18 @@ public class FLLiteClient { status = FLClientStatus.SUCCESS; retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("===================================evaluate model after getting model from server===================================")); - if (flParameter.getFlName().equals(ADBERT)) { - AdInferBert adInferBert = AdInferBert.getInstance(); - int dataSize = adInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true); + if (flParameter.getFlName().equals(ALBERT)) { + AlInferBert alInferBert = AlInferBert.getInstance(); + int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true); if (dataSize <= 0) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return dataSize<=0")); status = FLClientStatus.FAILED; retCode = ResponseCode.RequestError; return status; } - float acc = adInferBert.evalModel(); + float acc = alInferBert.evalModel(); if (acc == Float.NaN) { - LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); + LOGGER.severe(Common.addTag("[evaluate] unsolved error code in : the return acc is NAN")); status = FLClientStatus.FAILED; retCode = ResponseCode.RequestError; return status; @@ -471,9 +471,9 @@ public class FLLiteClient { retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("==========set input===========")); int dataSize = 0; - if (flParameter.getFlName().equals(ADBERT)) { - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - dataSize = adTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile()); + if (flParameter.getFlName().equals(ALBERT)) { + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile()); LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile())); } else if (flParameter.getFlName().equals(LENET)) { TrainLenet trainLenet = TrainLenet.getInstance(); @@ -490,18 +490,18 @@ public class FLLiteClient { public FLClientStatus initSession() { int tag = 0; retCode = ResponseCode.SUCCEED; - if (flParameter.getFlName().equals(ADBERT)) { + if (flParameter.getFlName().equals(ALBERT)) { LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); if (tag == -1) { LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); retCode = ResponseCode.RequestError; return FLClientStatus.FAILED; } LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session=============")); - AdInferBert adInferBert = AdInferBert.getInstance(); - tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); + AlInferBert alInferBert = AlInferBert.getInstance(); + tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); } else if (flParameter.getFlName().equals(LENET)) { LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); TrainLenet trainLenet = TrainLenet.getInstance(); @@ -517,14 +517,14 @@ public class FLLiteClient { @Override protected void finalize() { - if (flParameter.getFlName().equals(ADBERT)) { + if (flParameter.getFlName().equals(ALBERT)) { LOGGER.info(Common.addTag("===========free train session=============")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - SessionUtil.free(adTrainBert.getTrainSession()); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + SessionUtil.free(alTrainBert.getTrainSession()); if (!flParameter.getTestDataset().equals("null")) { LOGGER.info(Common.addTag("===========free inference session=============")); - AdInferBert adInferBert = AdInferBert.getInstance(); - SessionUtil.free(adInferBert.getTrainSession()); + AlInferBert alInferBert = AlInferBert.getInstance(); + SessionUtil.free(alInferBert.getTrainSession()); } } else if (flParameter.getFlName().equals(LENET)) { LOGGER.info(Common.addTag("===========free session=============")); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java index 9595f9d0308..e8079f02560 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/GetModel.java @@ -16,8 +16,8 @@ package com.mindspore.flclient; import com.google.flatbuffers.FlatBufferBuilder; -import com.mindspore.flclient.model.AdInferBert; -import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.AlInferBert; +import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.FeatureMap; @@ -29,6 +29,9 @@ import java.util.ArrayList; import java.util.Date; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + public class GetModel { static { System.loadLibrary("mindspore-lite-jni"); @@ -90,37 +93,57 @@ public class GetModel { } - private FLClientStatus parseResponseAdbert(ResponseGetModel responseDataBuf) { + private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) { int fmCount = responseDataBuf.featureMapLength(); - ArrayList albertFeatureMaps = new ArrayList(); - ArrayList inferFeatureMaps = new ArrayList(); - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = responseDataBuf.featureMap(i); - String featureName = feature.weightFullname(); - if (localFLParameter.getAlbertWeightName().contains(featureName)) { - albertFeatureMaps.add(feature); - inferFeatureMaps.add(feature); - } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { - inferFeatureMaps.add(feature); - } else { - continue; + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[getModel] into ")); + ArrayList albertFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = responseDataBuf.featureMap(i); + String featureName = feature.weightFullname(); + if (localFLParameter.getAlbertWeightName().contains(featureName)) { + albertFeatureMaps.add(feature); + inferFeatureMaps.add(feature); + } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { + inferFeatureMaps.add(feature); + } else { + continue; + } + LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------")); + AlInferBert alInferBert = AlInferBert.getInstance(); + tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + return FLClientStatus.FAILED; + } + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[getModel] into ")); + ArrayList featureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = responseDataBuf.featureMap(i); + String featureName = feature.weightFullname(); + featureMaps.add(feature); + LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); + return FLClientStatus.FAILED; } - LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); - } - int tag = 0; - LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------")); - AdInferBert adInferBert = AdInferBert.getInstance(); - tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); - return FLClientStatus.FAILED; - } - LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[getModel] unsolved error code in ")); - return FLClientStatus.FAILED; } return FLClientStatus.SUCCESS; } @@ -157,10 +180,11 @@ public class GetModel { switch (retCode) { case (ResponseCode.SUCCEED): LOGGER.info(Common.addTag("[getModel] getModel response success")); - if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[getModel] into ")); - status = parseResponseAdbert(responseDataBuf); - } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + + if (ALBERT.equals(flParameter.getFlName())) { + LOGGER.info(Common.addTag("[getModel] into ")); + status = parseResponseAlbert(responseDataBuf); + } else if (LENET.equals(flParameter.getFlName())) { LOGGER.info(Common.addTag("[getModel] into ")); status = parseResponseLenet(responseDataBuf); } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java index fdc32429678..9ceb6736412 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java @@ -24,7 +24,7 @@ public class LocalFLParameter { public static final int SEED_SIZE = 32; public static final int IVEC_LEN = 16; public static final String LENET = "lenet"; - public static final String ADBERT = "adbert"; + public static final String ALBERT = "albert"; private List classifierWeightName = new ArrayList<>(); private List albertWeightName = new ArrayList<>(); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java index a1c668463ce..6f44180a6c4 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SecureProtocol.java @@ -16,7 +16,7 @@ package com.mindspore.flclient; import com.google.flatbuffers.FlatBufferBuilder; -import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.FeatureMap; @@ -130,8 +130,8 @@ public class SecureProtocol { // get feature map Map map = new HashMap(); if (flParameter.getFlName().equals("adbert")) { - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); } else if (flParameter.getFlName().equals("lenet")) { TrainLenet trainLenet = TrainLenet.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); @@ -267,8 +267,8 @@ public class SecureProtocol { // get feature map Map map = new HashMap(); if (flParameter.getFlName().equals("adbert")) { - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); } else if (flParameter.getFlName().equals("lenet")) { TrainLenet trainLenet = TrainLenet.getInstance(); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java index 232259d6fe5..3e97fdeffa7 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java @@ -16,8 +16,8 @@ package com.mindspore.flclient; import com.google.flatbuffers.FlatBufferBuilder; -import com.mindspore.flclient.model.AdInferBert; -import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.AlInferBert; +import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.FeatureMap; @@ -28,6 +28,9 @@ import mindspore.schema.ResponseFLJob; import java.util.ArrayList; import java.util.logging.Logger; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; +import static com.mindspore.flclient.LocalFLParameter.LENET; + public class StartFLJob { static { System.loadLibrary("mindspore-lite-jni"); @@ -126,44 +129,67 @@ public class StartFLJob { return encryptFeatureName; } - private FLClientStatus parseResponseAdbert(ResponseFLJob flJob) { + private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) { int fmCount = flJob.featureMapLength(); - ArrayList albertFeatureMaps = new ArrayList(); - ArrayList inferFeatureMaps = new ArrayList(); encryptFeatureName.clear(); if (fmCount <= 0) { LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero")); return FLClientStatus.FAILED; } - for (int i = 0; i < fmCount; i++) { - FeatureMap feature = flJob.featureMap(i); - String featureName = feature.weightFullname(); - if (localFLParameter.getAlbertWeightName().contains(featureName)) { - albertFeatureMaps.add(feature); - inferFeatureMaps.add(feature); - featureSize += feature.dataLength(); - encryptFeatureName.add(feature.weightFullname()); - } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { - inferFeatureMaps.add(feature); - } else { - continue; + + if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod())); + ArrayList albertFeatureMaps = new ArrayList(); + ArrayList inferFeatureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + String featureName = feature.weightFullname(); + if (localFLParameter.getAlbertWeightName().contains(featureName)) { + albertFeatureMaps.add(feature); + inferFeatureMaps.add(feature); + featureSize += feature.dataLength(); + encryptFeatureName.add(feature.weightFullname()); + } else if (localFLParameter.getClassifierWeightName().contains(featureName)) { + inferFeatureMaps.add(feature); + } else { + continue; + } + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------")); + AlInferBert alInferBert = AlInferBert.getInstance(); + tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; + } + } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod())); + ArrayList featureMaps = new ArrayList(); + for (int i = 0; i < fmCount; i++) { + FeatureMap feature = flJob.featureMap(i); + String featureName = feature.weightFullname(); + featureMaps.add(feature); + featureSize += feature.dataLength(); + encryptFeatureName.add(featureName); + LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); + } + int tag = 0; + LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); + if (tag == -1) { + LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); + return FLClientStatus.FAILED; } - LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); - } - int tag = 0; - LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------")); - AdInferBert adInferBert = AdInferBert.getInstance(); - tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); - return FLClientStatus.FAILED; - } - LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); - if (tag == -1) { - LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in ")); - return FLClientStatus.FAILED; } return FLClientStatus.SUCCESS; } @@ -199,19 +225,20 @@ public class StartFLJob { LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime())); nextRequestTime = flJob.nextReqTime(); LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp())); - int retcode = flJob.retcode(); + FLClientStatus status = FLClientStatus.SUCCESS; + int retCode = flJob.retcode(); - switch (retcode) { + switch (retCode) { case (ResponseCode.SUCCEED): localFLParameter.setServerMod(flJob.flPlanConfig().serverMode()); - if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { - LOGGER.info(Common.addTag("[startFLJob] into ")); - parseResponseAdbert(flJob); - } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { + if (ALBERT.equals(flParameter.getFlName())) { + LOGGER.info(Common.addTag("[startFLJob] into ")); + status = parseResponseAlbert(flJob); + } else if (LENET.equals(flParameter.getFlName())) { LOGGER.info(Common.addTag("[startFLJob] into ")); - parseResponseLenet(flJob); + status = parseResponseLenet(flJob); } - return FLClientStatus.SUCCESS; + return status; case (ResponseCode.OutOfTime): return FLClientStatus.RESTART; case (ResponseCode.RequestError): @@ -219,7 +246,7 @@ public class StartFLJob { LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError")); return FLClientStatus.FAILED; default: - LOGGER.severe(Common.addTag("[startFLJob] the return from server is invalid: " + retcode)); + LOGGER.severe(Common.addTag("[startFLJob] the return from server is invalid: " + retCode)); return FLClientStatus.FAILED; } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java index 5eec006b140..433a8b8150e 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -15,8 +15,8 @@ */ package com.mindspore.flclient; -import com.mindspore.flclient.model.AdInferBert; -import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.AlInferBert; +import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.ResponseGetModel; @@ -28,7 +28,7 @@ import java.util.Map; import java.util.logging.Logger; import static com.mindspore.flclient.FLParameter.SLEEP_TIME; -import static com.mindspore.flclient.LocalFLParameter.ADBERT; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; import static com.mindspore.flclient.LocalFLParameter.LENET; public class SyncFLJob { @@ -203,9 +203,9 @@ public class SyncFLJob { private Map getFeatureMap() { Map featureMap = new HashMap<>(); - if (flParameter.getFlName().equals(ADBERT)) { - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + if (flParameter.getFlName().equals(ALBERT)) { + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); } else if (flParameter.getFlName().equals(LENET)) { TrainLenet trainLenet = TrainLenet.getInstance(); featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); @@ -215,12 +215,12 @@ public class SyncFLJob { public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) { int[] labels = new int[0]; - if (flName.equals(ADBERT)) { - AdInferBert adInferBert = AdInferBert.getInstance(); + if (flName.equals(ALBERT)) { + AlInferBert alInferBert = AlInferBert.getInstance(); LOGGER.info(Common.addTag("===========model inference=============")); - labels = adInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile); + labels = alInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile); LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); - SessionUtil.free(adInferBert.getTrainSession()); + SessionUtil.free(alInferBert.getTrainSession()); LOGGER.info(Common.addTag("[model inference] inference finish")); } else if (flName.equals(LENET)) { TrainLenet trainLenet = TrainLenet.getInstance(); @@ -240,18 +240,18 @@ public class SyncFLJob { int tag = 0; FLClientStatus status = FLClientStatus.SUCCESS; try { - if (flParameter.getFlName().equals(ADBERT)) { + if (flParameter.getFlName().equals(ALBERT)) { localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString()); LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); if (tag == -1) { LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); return FLClientStatus.FAILED; } LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session=============")); - AdInferBert adInferBert = AdInferBert.getInstance(); - tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); + AlInferBert alInferBert = AlInferBert.getInstance(); + tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); } else if (flParameter.getFlName().equals(LENET)) { localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString()); LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); @@ -278,13 +278,13 @@ public class SyncFLJob { LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage())); status = FLClientStatus.FAILED; } - if (flParameter.getFlName().equals(ADBERT)) { + if (flParameter.getFlName().equals(ALBERT)) { LOGGER.info(Common.addTag("===========free train session=============")); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - SessionUtil.free(adTrainBert.getTrainSession()); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + SessionUtil.free(alTrainBert.getTrainSession()); LOGGER.info(Common.addTag("===========free inference session=============")); - AdInferBert adInferBert = AdInferBert.getInstance(); - SessionUtil.free(adInferBert.getTrainSession()); + AlInferBert alInferBert = AlInferBert.getInstance(); + SessionUtil.free(alInferBert.getTrainSession()); } else if (flParameter.getFlName().equals(LENET)) { LOGGER.info(Common.addTag("===========free session=============")); TrainLenet trainLenet = TrainLenet.getInstance(); @@ -376,7 +376,7 @@ public class SyncFLJob { flParameter.setTimeWindow(timeWindow); flParameter.setUseElb(useElb); flParameter.setServerNum(serverNum); - if (ADBERT.equals(flName)) { + if (ALBERT.equals(flName)) { flParameter.setVocabFile(vocabFile); flParameter.setIdsFile(idsFile); } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java index 58e99c444a0..4fe4a1de2a7 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/UpdateModel.java @@ -17,7 +17,7 @@ package com.mindspore.flclient; import com.google.flatbuffers.FlatBufferBuilder; -import com.mindspore.flclient.model.AdTrainBert; +import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.TrainLenet; import mindspore.schema.FeatureMap; @@ -31,7 +31,7 @@ import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; -import static com.mindspore.flclient.LocalFLParameter.ADBERT; +import static com.mindspore.flclient.LocalFLParameter.ALBERT; import static com.mindspore.flclient.LocalFLParameter.LENET; public class UpdateModel { @@ -100,10 +100,10 @@ public class UpdateModel { case NOT_ENCRYPT: default: Map map = new HashMap(); - if (flParameter.getFlName().equals(ADBERT)) { + if (flParameter.getFlName().equals(ALBERT)) { LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName())); - AdTrainBert adTrainBert = AdTrainBert.getInstance(); - map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); + AlTrainBert alTrainBert = AlTrainBert.getInstance(); + map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession())); if (map.isEmpty()) { LOGGER.severe(Common.addTag("[updateModel] the return map is empty in ")); status = FLClientStatus.FAILED; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlBert.java similarity index 97% rename from mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java rename to mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlBert.java index fa2898009c2..fd66a21bcde 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdBert.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlBert.java @@ -25,8 +25,8 @@ import java.util.ArrayList; import java.util.List; import java.util.logging.Logger; -public class AdBert extends TrainModel { - private static final Logger logger = Logger.getLogger(AdBert.class.toString()); +public class AlBert extends TrainModel { + private static final Logger logger = Logger.getLogger(AlBert.class.toString()); private static final int NUM_OF_CLASS = 4; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdInferBert.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlInferBert.java similarity index 89% rename from mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdInferBert.java rename to mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlInferBert.java index de9ba8e7841..37368c7add5 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdInferBert.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlInferBert.java @@ -21,18 +21,18 @@ import com.mindspore.flclient.Common; import java.util.Arrays; import java.util.logging.Logger; -public class AdInferBert extends AdBert { - private static final Logger logger = Logger.getLogger(AdInferBert.class.toString()); +public class AlInferBert extends AlBert { + private static final Logger logger = Logger.getLogger(AlInferBert.class.toString()); - private static volatile AdInferBert adInferBert; + private static volatile AlInferBert alInferBert; - public static AdInferBert getInstance() { - AdInferBert localRef = adInferBert; + public static AlInferBert getInstance() { + AlInferBert localRef = alInferBert; if (localRef == null) { - synchronized (AdInferBert.class) { - localRef = adInferBert; + synchronized (AlInferBert.class) { + localRef = alInferBert; if (localRef == null) { - adInferBert = localRef = new AdInferBert(); + alInferBert = localRef = new AlInferBert(); } } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlTrainBert.java similarity index 75% rename from mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java rename to mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlTrainBert.java index 2851d8a49e9..338f2595346 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AdTrainBert.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/AlTrainBert.java @@ -20,18 +20,18 @@ import com.mindspore.flclient.Common; import java.util.logging.Logger; -public class AdTrainBert extends AdBert { - private static final Logger logger = Logger.getLogger(AdTrainBert.class.toString()); +public class AlTrainBert extends AlBert { + private static final Logger logger = Logger.getLogger(AlTrainBert.class.toString()); - private static volatile AdTrainBert adTrainBert; + private static volatile AlTrainBert alTrainBert; - public static AdTrainBert getInstance() { - AdTrainBert localRef = adTrainBert; + public static AlTrainBert getInstance() { + AlTrainBert localRef = alTrainBert; if (localRef == null) { - synchronized (AdTrainBert.class) { - localRef = adTrainBert; + synchronized (AlTrainBert.class) { + localRef = alTrainBert; if (localRef == null) { - adTrainBert = localRef = new AdTrainBert(); + alTrainBert = localRef = new AlTrainBert(); } } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java index 977eb3313b5..0266bede0e5 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/CustomTokenizer.java @@ -27,7 +27,7 @@ import java.util.*; import java.util.logging.Logger; public class CustomTokenizer { - private static final Logger logger = Logger.getLogger(AdInferBert.class.toString()); + private static final Logger logger = Logger.getLogger(CustomTokenizer.class.toString()); private Map vocabs = new HashMap<>(); private Boolean doLowerCase = Boolean.TRUE; private int maxInputChars = 100;