!19209 change version for compiled packages

Merge pull request !19209 from zhoushan33/flclient0701_om
This commit is contained in:
i-robot 2021-07-01 08:24:56 +00:00 committed by Gitee
commit 0c636df5a1
2 changed files with 37 additions and 24 deletions

View File

@ -45,13 +45,13 @@ dependencies {
testImplementation 'junit:junit:4.12'
// https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp
compile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.4'
testCompile group: 'com.squareup.okhttp3', name: 'mockwebserver', version: '3.14.4'
compile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.9'
testCompile group: 'com.squareup.okhttp3', name: 'mockwebserver', version: '3.14.9'
// https://mvnrepository.com/artifact/com.google.flatbuffers/flatbuffers-java
compile group: 'com.google.flatbuffers', name: 'flatbuffers-java', version: '1.11.0'
compile(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.63')
compile(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.68')
implementation project(':common')
implementation project(':linux_x86')

View File

@ -23,9 +23,9 @@ import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.logging.Logger;
import java.util.Map;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ADBERT;
@ -224,40 +224,38 @@ public class SyncFLJob {
return labels;
}
public FLClientStatus getModel(boolean useElb, int serverNum, String ip, int port, String flName, String trainModelPath, String inferModelPath, boolean useSSL) {
public FLClientStatus getModel() {
int tag = 0;
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
FLClientStatus status = FLClientStatus.SUCCESS;
try {
if (flName.equals(ADBERT)) {
if (flParameter.getFlName().equals(ADBERT)) {
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
tag = adTrainBert.initSessionAndInputs(trainModelPath, true);
tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + inferModelPath + " Create inference Session============="));
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
tag = adInferBert.initSessionAndInputs(inferModelPath, false);
} else if (flName.equals(LENET)) {
tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session============="));
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(trainModelPath, true);
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
return FLClientStatus.FAILED;
}
flParameter.setUseSSL(useSSL);
flParameter.setUseSSL(flParameter.isUseSSL());
FLCommunication flCommunication = FLCommunication.getInstance();
String url = Common.generateUrl(flParameter.isUseHttps(), useElb, ip, port, serverNum);
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "=============="));
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flName, 0);
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
LOGGER.info(Common.addTag("[getModel] get model request success"));
ByteBuffer debugBuffer = ByteBuffer.wrap(message);
@ -268,14 +266,14 @@ public class SyncFLJob {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
status = FLClientStatus.FAILED;
}
if (flName.equals(ADBERT)) {
if (flParameter.getFlName().equals(ADBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AdTrainBert adTrainBert = AdTrainBert.getInstance();
SessionUtil.free(adTrainBert.getTrainSession());
LOGGER.info(Common.addTag("===========free inference session============="));
AdInferBert adInferBert = AdInferBert.getInstance();
SessionUtil.free(adInferBert.getTrainSession());
} else if (flName.equals(LENET)) {
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("===========free session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
SessionUtil.free(trainLenet.getTrainSession());
@ -327,7 +325,8 @@ public class SyncFLJob {
boolean useElb = Boolean.parseBoolean(args[12]);
int serverNum = Integer.parseInt(args[13]);
boolean useHttps = Boolean.parseBoolean(args[14]);
String task = args[15];
String certPath = args[15];
String task = args[16];
FLParameter flParameter = FLParameter.getInstance();
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
@ -345,12 +344,15 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
LOGGER.info(Common.addTag("[args] useHttps: " + useHttps));
LOGGER.info(Common.addTag("[args] certPath: " + certPath));
LOGGER.info(Common.addTag("[args] task: " + task));
flParameter.setClientID(clientID);
flParameter.setUseHttps(useHttps);
SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) {
flParameter.setUseHttps(useHttps);
flParameter.setCertPath(certPath);
flParameter.setHostName(ip);
flParameter.setTrainDataset(trainDataset);
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
@ -370,7 +372,18 @@ public class SyncFLJob {
} else if (task.equals("inference")) {
syncFLJob.modelInference(flName, testDataset, vocabFile, idsFile, inferModelPath);
} else if (task.equals("getModel")) {
syncFLJob.getModel(false, 1, ip, port, flName, trainModelPath, inferModelPath, false);
flParameter.setUseHttps(useHttps);
flParameter.setCertPath(certPath);
flParameter.setHostName(ip);
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
flParameter.setIp(ip);
flParameter.setUseSSL(useSSL);
flParameter.setPort(port);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
syncFLJob.getModel();
} else {
LOGGER.info(Common.addTag("do not do any thing!"));
}