forked from mindspore-Ecosystem/mindspore
add code of albert for fl_client
This commit is contained in:
parent
a0225a5481
commit
0c985c0703
|
@ -28,7 +28,7 @@ import java.util.regex.Pattern;
|
|||
public class Common {
|
||||
public static final String LOG_TITLE = "<FLClient> ";
|
||||
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
||||
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "adbert"));
|
||||
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert"));
|
||||
|
||||
public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) {
|
||||
if (useHttps) {
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.cipher.BaseUtil;
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.CipherPublicParams;
|
||||
|
@ -37,7 +37,7 @@ import java.util.TreeMap;
|
|||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class FLLiteClient {
|
||||
|
@ -83,9 +83,9 @@ public class FLLiteClient {
|
|||
localFLParameter.setServerMod(serverMod);
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <serverMod> from server: " + serverMod));
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AdTrainBert: " + batchSize));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
adTrainBert.setBatchSize(batchSize);
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for AlTrainBert: " + batchSize));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
alTrainBert.setBatchSize(batchSize);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
|
@ -227,12 +227,12 @@ public class FLLiteClient {
|
|||
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[train] train in adbert"));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
int tag = adTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
int tag = alTrainBert.trainModel(flParameter.getTrainModelPath(), epochs);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[train] unsolved error code in <adTrainBert.trainModel>"));
|
||||
LOGGER.severe(Common.addTag("[train] unsolved error code in <alTrainBert.trainModel>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
|
@ -366,9 +366,9 @@ public class FLLiteClient {
|
|||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
|
@ -424,18 +424,18 @@ public class FLLiteClient {
|
|||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server==================================="));
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
int dataSize = adInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.initDataSet>: the return dataSize<=0"));
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
float acc = adInferBert.evalModel();
|
||||
float acc = alInferBert.evalModel();
|
||||
if (acc == Float.NaN) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <adTrainBert.evalModel>: the return acc is NAN"));
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.evalModel>: the return acc is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
|
@ -471,9 +471,9 @@ public class FLLiteClient {
|
|||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("==========set input==========="));
|
||||
int dataSize = 0;
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
dataSize = adTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
|
@ -490,18 +490,18 @@ public class FLLiteClient {
|
|||
public FLClientStatus initSession() {
|
||||
int tag = 0;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
|
@ -517,14 +517,14 @@ public class FLLiteClient {
|
|||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
SessionUtil.free(adTrainBert.getTrainSession());
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
SessionUtil.free(alTrainBert.getTrainSession());
|
||||
if (!flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -29,6 +29,9 @@ import java.util.ArrayList;
|
|||
import java.util.Date;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class GetModel {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
|
@ -90,37 +93,57 @@ public class GetModel {
|
|||
}
|
||||
|
||||
|
||||
private FLClientStatus parseResponseAdbert(ResponseGetModel responseDataBuf) {
|
||||
private FLClientStatus parseResponseAlbert(ResponseGetModel responseDataBuf) {
|
||||
int fmCount = responseDataBuf.featureMapLength();
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
@ -157,10 +180,11 @@ public class GetModel {
|
|||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[getModel] getModel response success"));
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAdbert>"));
|
||||
status = parseResponseAdbert(responseDataBuf);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
status = parseResponseAlbert(responseDataBuf);
|
||||
} else if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
status = parseResponseLenet(responseDataBuf);
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ public class LocalFLParameter {
|
|||
public static final int SEED_SIZE = 32;
|
||||
public static final int IVEC_LEN = 16;
|
||||
public static final String LENET = "lenet";
|
||||
public static final String ADBERT = "adbert";
|
||||
public static final String ALBERT = "albert";
|
||||
private List<String> classifierWeightName = new ArrayList<>();
|
||||
private List<String> albertWeightName = new ArrayList<>();
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -130,8 +130,8 @@ public class SecureProtocol {
|
|||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals("adbert")) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals("lenet")) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
|
@ -267,8 +267,8 @@ public class SecureProtocol {
|
|||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals("adbert")) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals("lenet")) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -28,6 +28,9 @@ import mindspore.schema.ResponseFLJob;
|
|||
import java.util.ArrayList;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class StartFLJob {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
|
@ -126,44 +129,67 @@ public class StartFLJob {
|
|||
return encryptFeatureName;
|
||||
}
|
||||
|
||||
private FLClientStatus parseResponseAdbert(ResponseFLJob flJob) {
|
||||
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
|
||||
int fmCount = flJob.featureMapLength();
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
encryptFeatureName.clear();
|
||||
if (fmCount <= 0) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature size get from server is zero"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(feature.weightFullname());
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> albertFeatureMaps = new ArrayList<FeatureMap>();
|
||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
inferFeatureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(feature.weightFullname());
|
||||
} else if (localFLParameter.getClassifierWeightName().contains(featureName)) {
|
||||
inferFeatureMaps.add(feature);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] parseResponseAlbert by " + localFLParameter.getServerMod()));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(adTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
@ -199,19 +225,20 @@ public class StartFLJob {
|
|||
LOGGER.info(Common.addTag("[startFLJob] next request time: " + flJob.nextReqTime()));
|
||||
nextRequestTime = flJob.nextReqTime();
|
||||
LOGGER.info(Common.addTag("[startFLJob] timestamp: " + flJob.timestamp()));
|
||||
int retcode = flJob.retcode();
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
int retCode = flJob.retcode();
|
||||
|
||||
switch (retcode) {
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAdbert>"));
|
||||
parseResponseAdbert(flJob);
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
|
||||
status = parseResponseAlbert(flJob);
|
||||
} else if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
||||
parseResponseLenet(flJob);
|
||||
status = parseResponseLenet(flJob);
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
return status;
|
||||
case (ResponseCode.OutOfTime):
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -219,7 +246,7 @@ public class StartFLJob {
|
|||
LOGGER.info(Common.addTag("[startFLJob] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retcode));
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the return <retcode> from server is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.model.AdInferBert;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
|
@ -28,7 +28,7 @@ import java.util.Map;
|
|||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class SyncFLJob {
|
||||
|
@ -203,9 +203,9 @@ public class SyncFLJob {
|
|||
|
||||
private Map<String, float[]> getFeatureMap() {
|
||||
Map<String, float[]> featureMap = new HashMap<>();
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
featureMap = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
|
@ -215,12 +215,12 @@ public class SyncFLJob {
|
|||
|
||||
public int[] modelInference(String flName, String dataPath, String vocabFile, String idsFile, String modelPath) {
|
||||
int[] labels = new int[0];
|
||||
if (flName.equals(ADBERT)) {
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
if (flName.equals(ALBERT)) {
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
LOGGER.info(Common.addTag("===========model inference============="));
|
||||
labels = adInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile);
|
||||
labels = alInferBert.inferModel(modelPath, dataPath, vocabFile, idsFile);
|
||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
} else if (flName.equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
|
@ -240,18 +240,18 @@ public class SyncFLJob {
|
|||
int tag = 0;
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
try {
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
|
@ -278,13 +278,13 @@ public class SyncFLJob {
|
|||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
SessionUtil.free(adTrainBert.getTrainSession());
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
SessionUtil.free(alTrainBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
|
@ -376,7 +376,7 @@ public class SyncFLJob {
|
|||
flParameter.setTimeWindow(timeWindow);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
if (ADBERT.equals(flName)) {
|
||||
if (ALBERT.equals(flName)) {
|
||||
flParameter.setVocabFile(vocabFile);
|
||||
flParameter.setIdsFile(idsFile);
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package com.mindspore.flclient;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.mindspore.flclient.model.AdTrainBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -31,7 +31,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
public class UpdateModel {
|
||||
|
@ -100,10 +100,10 @@ public class UpdateModel {
|
|||
case NOT_ENCRYPT:
|
||||
default:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
|
|
|
@ -25,8 +25,8 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class AdBert extends TrainModel {
|
||||
private static final Logger logger = Logger.getLogger(AdBert.class.toString());
|
||||
public class AlBert extends TrainModel {
|
||||
private static final Logger logger = Logger.getLogger(AlBert.class.toString());
|
||||
|
||||
private static final int NUM_OF_CLASS = 4;
|
||||
|
|
@ -21,18 +21,18 @@ import com.mindspore.flclient.Common;
|
|||
import java.util.Arrays;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class AdInferBert extends AdBert {
|
||||
private static final Logger logger = Logger.getLogger(AdInferBert.class.toString());
|
||||
public class AlInferBert extends AlBert {
|
||||
private static final Logger logger = Logger.getLogger(AlInferBert.class.toString());
|
||||
|
||||
private static volatile AdInferBert adInferBert;
|
||||
private static volatile AlInferBert alInferBert;
|
||||
|
||||
public static AdInferBert getInstance() {
|
||||
AdInferBert localRef = adInferBert;
|
||||
public static AlInferBert getInstance() {
|
||||
AlInferBert localRef = alInferBert;
|
||||
if (localRef == null) {
|
||||
synchronized (AdInferBert.class) {
|
||||
localRef = adInferBert;
|
||||
synchronized (AlInferBert.class) {
|
||||
localRef = alInferBert;
|
||||
if (localRef == null) {
|
||||
adInferBert = localRef = new AdInferBert();
|
||||
alInferBert = localRef = new AlInferBert();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,18 +20,18 @@ import com.mindspore.flclient.Common;
|
|||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class AdTrainBert extends AdBert {
|
||||
private static final Logger logger = Logger.getLogger(AdTrainBert.class.toString());
|
||||
public class AlTrainBert extends AlBert {
|
||||
private static final Logger logger = Logger.getLogger(AlTrainBert.class.toString());
|
||||
|
||||
private static volatile AdTrainBert adTrainBert;
|
||||
private static volatile AlTrainBert alTrainBert;
|
||||
|
||||
public static AdTrainBert getInstance() {
|
||||
AdTrainBert localRef = adTrainBert;
|
||||
public static AlTrainBert getInstance() {
|
||||
AlTrainBert localRef = alTrainBert;
|
||||
if (localRef == null) {
|
||||
synchronized (AdTrainBert.class) {
|
||||
localRef = adTrainBert;
|
||||
synchronized (AlTrainBert.class) {
|
||||
localRef = alTrainBert;
|
||||
if (localRef == null) {
|
||||
adTrainBert = localRef = new AdTrainBert();
|
||||
alTrainBert = localRef = new AlTrainBert();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -27,7 +27,7 @@ import java.util.*;
|
|||
import java.util.logging.Logger;
|
||||
|
||||
public class CustomTokenizer {
|
||||
private static final Logger logger = Logger.getLogger(AdInferBert.class.toString());
|
||||
private static final Logger logger = Logger.getLogger(CustomTokenizer.class.toString());
|
||||
private Map<String, Integer> vocabs = new HashMap<>();
|
||||
private Boolean doLowerCase = Boolean.TRUE;
|
||||
private int maxInputChars = 100;
|
||||
|
|
Loading…
Reference in New Issue