!19209 change version for compiled packages
Merge pull request !19209 from zhoushan33/flclient0701_om
This commit is contained in:
commit
0c636df5a1
|
@ -45,13 +45,13 @@ dependencies {
|
|||
testImplementation 'junit:junit:4.12'
|
||||
|
||||
// https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp
|
||||
compile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.4'
|
||||
testCompile group: 'com.squareup.okhttp3', name: 'mockwebserver', version: '3.14.4'
|
||||
compile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.9'
|
||||
testCompile group: 'com.squareup.okhttp3', name: 'mockwebserver', version: '3.14.9'
|
||||
|
||||
// https://mvnrepository.com/artifact/com.google.flatbuffers/flatbuffers-java
|
||||
compile group: 'com.google.flatbuffers', name: 'flatbuffers-java', version: '1.11.0'
|
||||
|
||||
compile(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.63')
|
||||
compile(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.68')
|
||||
implementation project(':common')
|
||||
implementation project(':linux_x86')
|
||||
|
||||
|
|
|
@ -23,9 +23,9 @@ import mindspore.schema.ResponseGetModel;
|
|||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
|
||||
|
@ -224,40 +224,38 @@ public class SyncFLJob {
|
|||
return labels;
|
||||
}
|
||||
|
||||
public FLClientStatus getModel(boolean useElb, int serverNum, String ip, int port, String flName, String trainModelPath, String inferModelPath, boolean useSSL) {
|
||||
public FLClientStatus getModel() {
|
||||
int tag = 0;
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
try {
|
||||
if (flName.equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
tag = adTrainBert.initSessionAndInputs(trainModelPath, true);
|
||||
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + inferModelPath + " Create inference Session============="));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
tag = adInferBert.initSessionAndInputs(inferModelPath, false);
|
||||
} else if (flName.equals(LENET)) {
|
||||
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(trainModelPath, true);
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setUseSSL(flParameter.isUseSSL());
|
||||
FLCommunication flCommunication = FLCommunication.getInstance();
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), useElb, ip, port, serverNum);
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
|
||||
GetModel getModelBuf = GetModel.getInstance();
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flName, 0);
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
|
||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||
LOGGER.info(Common.addTag("[getModel] get model request success"));
|
||||
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
|
||||
|
@ -268,14 +266,14 @@ public class SyncFLJob {
|
|||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
if (flName.equals(ADBERT)) {
|
||||
if (flParameter.getFlName().equals(ADBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AdTrainBert adTrainBert = AdTrainBert.getInstance();
|
||||
SessionUtil.free(adTrainBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("===========free inference session============="));
|
||||
AdInferBert adInferBert = AdInferBert.getInstance();
|
||||
SessionUtil.free(adInferBert.getTrainSession());
|
||||
} else if (flName.equals(LENET)) {
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("===========free session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
|
@ -327,7 +325,8 @@ public class SyncFLJob {
|
|||
boolean useElb = Boolean.parseBoolean(args[12]);
|
||||
int serverNum = Integer.parseInt(args[13]);
|
||||
boolean useHttps = Boolean.parseBoolean(args[14]);
|
||||
String task = args[15];
|
||||
String certPath = args[15];
|
||||
String task = args[16];
|
||||
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
|
||||
|
@ -345,12 +344,15 @@ public class SyncFLJob {
|
|||
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
|
||||
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
|
||||
LOGGER.info(Common.addTag("[args] useHttps: " + useHttps));
|
||||
LOGGER.info(Common.addTag("[args] certPath: " + certPath));
|
||||
LOGGER.info(Common.addTag("[args] task: " + task));
|
||||
|
||||
flParameter.setClientID(clientID);
|
||||
flParameter.setUseHttps(useHttps);
|
||||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
if (task.equals("train")) {
|
||||
flParameter.setUseHttps(useHttps);
|
||||
flParameter.setCertPath(certPath);
|
||||
flParameter.setHostName(ip);
|
||||
flParameter.setTrainDataset(trainDataset);
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
|
@ -370,7 +372,18 @@ public class SyncFLJob {
|
|||
} else if (task.equals("inference")) {
|
||||
syncFLJob.modelInference(flName, testDataset, vocabFile, idsFile, inferModelPath);
|
||||
} else if (task.equals("getModel")) {
|
||||
syncFLJob.getModel(false, 1, ip, port, flName, trainModelPath, inferModelPath, false);
|
||||
flParameter.setUseHttps(useHttps);
|
||||
flParameter.setCertPath(certPath);
|
||||
flParameter.setHostName(ip);
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setIp(ip);
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setPort(port);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
syncFLJob.getModel();
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("do not do any thing!"));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue