add code of albert for fl_client

This commit is contained in:
zhoushan 2021-07-07 16:11:01 +08:00
parent a0225a5481
commit 0c985c0703
12 changed files with 212 additions and 161 deletions

View File

@ -28,7 +28,7 @@ import java.util.regex.Pattern;
public class Common {
public static final String LOG_TITLE = "<FLClient> ";
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "adbert"));
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert"));
public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) {
if (useHttps) {

View File

@ -17,8 +17,8 @@
package com.mindspore.flclient;
import com.mindspore.flclient.cipher.BaseUtil;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.CipherPublicParams;
@ -37,7 +37,7 @@ import java.util.TreeMap;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class FLLiteClient {
@ -83,9 +83,9 @@ public class FLLiteClient {
localFLParameter.setServerMod(serverMod);
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <serverMod> from server: " + serverMod));
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AdTrainBert: " + batchSize));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
adTrainBert.setBatchSize(batchSize);
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AlTrainBert: " + batchSize));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
alTrainBert.setBatchSize(batchSize);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
TrainLenet trainLenet = TrainLenet.getInstance();
@ -227,12 +227,12 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ADBERT)) {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[train] train in adbert"));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
int tag = adTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
if (tag == -1) {
LOGGER.severe(Common.addTag("[train] unsolved error code in <adTrainBert.trainModel>"));
LOGGER.severe(Common.addTag("[train] unsolved error code in <alTrainBert.trainModel>"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
@ -366,9 +366,9 @@ public class FLLiteClient {
return curStatus;
case DP_ENCRYPT:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ADBERT)) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
@ -424,18 +424,18 @@ public class FLLiteClient {
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server==================================="));
if (flParameter.getFlName().equals(ADBERT)) {
AdInferBert adInferBert = AdInferBert.getInstance();
int dataSize = adInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
if (flParameter.getFlName().equals(ALBERT)) {
AlInferBert alInferBert = AlInferBert.getInstance();
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.initDataSet>: the return dataSize<=0"));
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
float acc = adInferBert.evalModel();
float acc = alInferBert.evalModel();
if (acc == Float.NaN) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.evalModel>: the return acc is NAN"));
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
@ -471,9 +471,9 @@ public class FLLiteClient {
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("==========set input==========="));
int dataSize = 0;
if (flParameter.getFlName().equals(ADBERT)) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
dataSize = adTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
@ -490,18 +490,18 @@ public class FLLiteClient {
public FLClientStatus initSession() {
int tag = 0;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ADBERT)) {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
@ -517,14 +517,14 @@ public class FLLiteClient {
@Override
protected void finalize() {
if (flParameter.getFlName().equals(ADBERT)) {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
SessionUtil.free(adTrainBert.getTrainSession());
AlTrainBert alTrainBert = AlTrainBert.getInstance();
SessionUtil.free(alTrainBert.getTrainSession());
if (!flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("===========free inference session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
SessionUtil.free(adInferBert.getTrainSession());
AlInferBert alInferBert = AlInferBert.getInstance();
SessionUtil.free(alInferBert.getTrainSession());
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));

View File

@ -16,8 +16,8 @@
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
@ -29,6 +29,9 @@ import java.util.ArrayList;
import java.util.Date;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class GetModel {
static {
System.loadLibrary("mindspore-lite-jni");
@ -90,37 +93,57 @@ public class GetModel {
}
private FLClientStatus parseResponseAdbert(ResponseGetModel responseDataBuf) {
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
int fmCount = responseDataBuf.featureMapLength();
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
}
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
String featureName = feature.weightFullname();
featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
@ -157,10 +180,11 @@ public class GetModel {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[getModel] getModel response success"));
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
status = parseResponseAdbert(responseDataBuf);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
status = parseResponseAlbert(responseDataBuf);
} else if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
status = parseResponseLenet(responseDataBuf);
}

View File

@ -24,7 +24,7 @@ public class LocalFLParameter {
public static final int SEED_SIZE = 32;
public static final int IVEC_LEN = 16;
public static final String LENET = "lenet";
public static final String ADBERT = "adbert";
public static final String ALBERT = "albert";
private List<String> classifierWeightName = new ArrayList<>();
private List<String> albertWeightName = new ArrayList<>();

View File

@ -16,7 +16,7 @@
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
@ -130,8 +130,8 @@ public class SecureProtocol {
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
@ -267,8 +267,8 @@ public class SecureProtocol {
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals("adbert")) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals("lenet")) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));

View File

@ -16,8 +16,8 @@
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
@ -28,6 +28,9 @@ import mindspore.schema.ResponseFLJob;
import java.util.ArrayList;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class StartFLJob {
static {
System.loadLibrary("mindspore-lite-jni");
@ -126,44 +129,67 @@ public class StartFLJob {
return encryptFeatureName;
}
private FLClientStatus parseResponseAdbert(ResponseFLJob flJob) {
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength();
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
encryptFeatureName.clear();
if (fmCount <= 0) {
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
return FLClientStatus.FAILED;
}
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(feature.weightFullname());
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
inferFeatureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(feature.weightFullname());
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
inferFeatureMaps.add(feature);
} else {
continue;
}
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
@ -199,19 +225,20 @@ public class StartFLJob {
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
nextRequestTime = flJob.nextReqTime();
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
int retcode = flJob.retcode();
FLClientStatus status = FLClientStatus.SUCCESS;
int retCode = flJob.retcode();
switch (retcode) {
switch (retCode) {
case (ResponseCode.SUCCEED):
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAdbert>"));
parseResponseAdbert(flJob);
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
status = parseResponseAlbert(flJob);
} else if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
parseResponseLenet(flJob);
status = parseResponseLenet(flJob);
}
return FLClientStatus.SUCCESS;
return status;
case (ResponseCode.OutOfTime):
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -219,7 +246,7 @@ public class StartFLJob {
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retcode));
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}

View File

@ -15,8 +15,8 @@
*/
package com.mindspore.flclient;
import com.mindspore.flclient.model.AdInferBert;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseGetModel;
@ -28,7 +28,7 @@ import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class SyncFLJob {
@ -203,9 +203,9 @@ public class SyncFLJob {
private Map<String, float[]> getFeatureMap() {
Map<String, float[]> featureMap = new HashMap<>();
if (flParameter.getFlName().equals(ADBERT)) {
AdTrainBert adTrainBert = AdTrainBert.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
@ -215,12 +215,12 @@ public class SyncFLJob {
public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) {
int[] labels = new int[0];
if (flName.equals(ADBERT)) {
AdInferBert adInferBert = AdInferBert.getInstance();
if (flName.equals(ALBERT)) {
AlInferBert alInferBert = AlInferBert.getInstance();
LOGGER.info(Common.addTag("===========model inference============="));
labels = adInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile);
labels = alInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile);
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(adInferBert.getTrainSession());
SessionUtil.free(alInferBert.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish"));
} else if (flName.equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
@ -240,18 +240,18 @@ public class SyncFLJob {
int tag = 0;
FLClientStatus status = FLClientStatus.SUCCESS;
try {
if (flParameter.getFlName().equals(ADBERT)) {
if (flParameter.getFlName().equals(ALBERT)) {
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
@ -278,13 +278,13 @@ public class SyncFLJob {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
status = FLClientStatus.FAILED;
}
if (flParameter.getFlName().equals(ADBERT)) {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
SessionUtil.free(adTrainBert.getTrainSession());
AlTrainBert alTrainBert = AlTrainBert.getInstance();
SessionUtil.free(alTrainBert.getTrainSession());
LOGGER.info(Common.addTag("===========free inference session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
SessionUtil.free(adInferBert.getTrainSession());
AlInferBert alInferBert = AlInferBert.getInstance();
SessionUtil.free(alInferBert.getTrainSession());
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
@ -376,7 +376,7 @@ public class SyncFLJob {
flParameter.setTimeWindow(timeWindow);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
if (ADBERT.equals(flName)) {
if (ALBERT.equals(flName)) {
flParameter.setVocabFile(vocabFile);
flParameter.setIdsFile(idsFile);
}

View File

@ -17,7 +17,7 @@
package com.mindspore.flclient;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AdTrainBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
@ -31,7 +31,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
public class UpdateModel {
@ -100,10 +100,10 @@ public class UpdateModel {
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ADBERT)) {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;

View File

@ -25,8 +25,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
public class AdBert extends TrainModel {
private static final Logger logger = Logger.getLogger(AdBert.class.toString());
public class AlBert extends TrainModel {
private static final Logger logger = Logger.getLogger(AlBert.class.toString());
private static final int NUM_OF_CLASS = 4;

View File

@ -21,18 +21,18 @@ import com.mindspore.flclient.Common;
import java.util.Arrays;
import java.util.logging.Logger;
public class AdInferBert extends AdBert {
private static final Logger logger = Logger.getLogger(AdInferBert.class.toString());
public class AlInferBert extends AlBert {
private static final Logger logger = Logger.getLogger(AlInferBert.class.toString());
private static volatile AdInferBert adInferBert;
private static volatile AlInferBert alInferBert;
public static AdInferBert getInstance() {
AdInferBert localRef = adInferBert;
public static AlInferBert getInstance() {
AlInferBert localRef = alInferBert;
if (localRef == null) {
synchronized (AdInferBert.class) {
localRef = adInferBert;
synchronized (AlInferBert.class) {
localRef = alInferBert;
if (localRef == null) {
adInferBert = localRef = new AdInferBert();
alInferBert = localRef = new AlInferBert();
}
}
}

View File

@ -20,18 +20,18 @@ import com.mindspore.flclient.Common;
import java.util.logging.Logger;
public class AdTrainBert extends AdBert {
private static final Logger logger = Logger.getLogger(AdTrainBert.class.toString());
public class AlTrainBert extends AlBert {
private static final Logger logger = Logger.getLogger(AlTrainBert.class.toString());
private static volatile AdTrainBert adTrainBert;
private static volatile AlTrainBert alTrainBert;
public static AdTrainBert getInstance() {
AdTrainBert localRef = adTrainBert;
public static AlTrainBert getInstance() {
AlTrainBert localRef = alTrainBert;
if (localRef == null) {
synchronized (AdTrainBert.class) {
localRef = adTrainBert;
synchronized (AlTrainBert.class) {
localRef = alTrainBert;
if (localRef == null) {
adTrainBert = localRef = new AdTrainBert();
alTrainBert = localRef = new AlTrainBert();
}
}
}

View File

@ -27,7 +27,7 @@ import java.util.*;
import java.util.logging.Logger;
public class CustomTokenizer {
private static final Logger logger = Logger.getLogger(AdInferBert.class.toString());
private static final Logger logger = Logger.getLogger(CustomTokenizer.class.toString());
private Map<String, Integer> vocabs = new HashMap<>();
private Boolean doLowerCase = Boolean.TRUE;
private int maxInputChars = 100;