From b39d64e2abc11018637f902677c1bd999cc33ab3 Mon Sep 17 00:00:00 2001 From: zhoushan Date: Thu, 1 Jul 2021 12:02:58 +0800 Subject: [PATCH] change version for compiled packages --- .../lite/java/java/fl_client/build.gradle | 6 +- .../com/mindspore/flclient/SyncFLJob.java | 55 ++++++++++++------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/mindspore/lite/java/java/fl_client/build.gradle b/mindspore/lite/java/java/fl_client/build.gradle index defa0a5fc42..24e0df2807d 100644 --- a/mindspore/lite/java/java/fl_client/build.gradle +++ b/mindspore/lite/java/java/fl_client/build.gradle @@ -45,13 +45,13 @@ dependencies { testImplementation 'junit:junit:4.12' // https://mvnrepository.com/artifact/com.squareup.okhttp3/okhttp - compile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.4' - testCompile group: 'com.squareup.okhttp3', name: 'mockwebserver', version: '3.14.4' + compile group: 'com.squareup.okhttp3', name: 'okhttp', version: '3.14.9' + testCompile group: 'com.squareup.okhttp3', name: 'mockwebserver', version: '3.14.9' // https://mvnrepository.com/artifact/com.google.flatbuffers/flatbuffers-java compile group: 'com.google.flatbuffers', name: 'flatbuffers-java', version: '1.11.0' - compile(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.63') + compile(group: 'org.bouncycastle',name: 'bcprov-jdk15on', version: '1.68') implementation project(':common') implementation project(':linux_x86') diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java index 349dcd3a262..8f961a1ada4 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -23,9 +23,9 @@ import mindspore.schema.ResponseGetModel; import java.nio.ByteBuffer; import java.util.Arrays; -import java.util.logging.Logger; -import java.util.Map; import java.util.HashMap; +import java.util.Map; +import java.util.logging.Logger; import static com.mindspore.flclient.FLParameter.SLEEP_TIME; import static com.mindspore.flclient.LocalFLParameter.ADBERT; @@ -224,40 +224,38 @@ public class SyncFLJob { return labels; } - public FLClientStatus getModel(boolean useElb, int serverNum, String ip, int port, String flName, String trainModelPath, String inferModelPath, boolean useSSL) { + public FLClientStatus getModel() { int tag = 0; - flParameter.setTrainModelPath(trainModelPath); - flParameter.setInferModelPath(inferModelPath); FLClientStatus status = FLClientStatus.SUCCESS; try { - if (flName.equals(ADBERT)) { + if (flParameter.getFlName().equals(ADBERT)) { localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString()); - LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session=============")); + LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); AdTrainBert adTrainBert = AdTrainBert.getInstance(); - tag = adTrainBert.initSessionAndInputs(trainModelPath, true); + tag = adTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); if (tag == -1) { LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); return FLClientStatus.FAILED; } - LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + inferModelPath + " Create inference Session=============")); + LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session=============")); AdInferBert adInferBert = AdInferBert.getInstance(); - tag = adInferBert.initSessionAndInputs(inferModelPath, false); - } else if (flName.equals(LENET)) { + tag = adInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); + } else if (flParameter.getFlName().equals(LENET)) { localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString()); - LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + trainModelPath + " Create Train Session=============")); + LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); TrainLenet trainLenet = TrainLenet.getInstance(); - tag = trainLenet.initSessionAndInputs(trainModelPath, true); + tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); } if (tag == -1) { LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return is -1")); return FLClientStatus.FAILED; } - flParameter.setUseSSL(useSSL); + flParameter.setUseSSL(flParameter.isUseSSL()); FLCommunication flCommunication = FLCommunication.getInstance(); - String url = Common.generateUrl(flParameter.isUseHttps(), useElb, ip, port, serverNum); + String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); LOGGER.info(Common.addTag("[getModel] ===========getModel url: " + url + "==============")); GetModel getModelBuf = GetModel.getInstance(); - byte[] buffer = getModelBuf.getRequestGetModel(flName, 0); + byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0); byte[] message = flCommunication.syncRequest(url + "/getModel", buffer); LOGGER.info(Common.addTag("[getModel] get model request success")); ByteBuffer debugBuffer = ByteBuffer.wrap(message); @@ -268,14 +266,14 @@ public class SyncFLJob { LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage())); status = FLClientStatus.FAILED; } - if (flName.equals(ADBERT)) { + if (flParameter.getFlName().equals(ADBERT)) { LOGGER.info(Common.addTag("===========free train session=============")); AdTrainBert adTrainBert = AdTrainBert.getInstance(); SessionUtil.free(adTrainBert.getTrainSession()); LOGGER.info(Common.addTag("===========free inference session=============")); AdInferBert adInferBert = AdInferBert.getInstance(); SessionUtil.free(adInferBert.getTrainSession()); - } else if (flName.equals(LENET)) { + } else if (flParameter.getFlName().equals(LENET)) { LOGGER.info(Common.addTag("===========free session=============")); TrainLenet trainLenet = TrainLenet.getInstance(); SessionUtil.free(trainLenet.getTrainSession()); @@ -327,7 +325,8 @@ public class SyncFLJob { boolean useElb = Boolean.parseBoolean(args[12]); int serverNum = Integer.parseInt(args[13]); boolean useHttps = Boolean.parseBoolean(args[14]); - String task = args[15]; + String certPath = args[15]; + String task = args[16]; FLParameter flParameter = FLParameter.getInstance(); LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset)); @@ -345,12 +344,15 @@ public class SyncFLJob { LOGGER.info(Common.addTag("[args] useElb: " + useElb)); LOGGER.info(Common.addTag("[args] serverNum: " + serverNum)); LOGGER.info(Common.addTag("[args] useHttps: " + useHttps)); + LOGGER.info(Common.addTag("[args] certPath: " + certPath)); LOGGER.info(Common.addTag("[args] task: " + task)); flParameter.setClientID(clientID); - flParameter.setUseHttps(useHttps); SyncFLJob syncFLJob = new SyncFLJob(); if (task.equals("train")) { + flParameter.setUseHttps(useHttps); + flParameter.setCertPath(certPath); + flParameter.setHostName(ip); flParameter.setTrainDataset(trainDataset); flParameter.setFlName(flName); flParameter.setTrainModelPath(trainModelPath); @@ -370,7 +372,18 @@ public class SyncFLJob { } else if (task.equals("inference")) { syncFLJob.modelInference(flName, testDataset, vocabFile, idsFile, inferModelPath); } else if (task.equals("getModel")) { - syncFLJob.getModel(false, 1, ip, port, flName, trainModelPath, inferModelPath, false); + flParameter.setUseHttps(useHttps); + flParameter.setCertPath(certPath); + flParameter.setHostName(ip); + flParameter.setFlName(flName); + flParameter.setTrainModelPath(trainModelPath); + flParameter.setInferModelPath(inferModelPath); + flParameter.setIp(ip); + flParameter.setUseSSL(useSSL); + flParameter.setPort(port); + flParameter.setUseElb(useElb); + flParameter.setServerNum(serverNum); + syncFLJob.getModel(); } else { LOGGER.info(Common.addTag("do not do any thing!")); }