From 5f8d935e1231089572efba7c43ea7ca4e06189f9 Mon Sep 17 00:00:00 2001 From: zhoushan Date: Sat, 22 Jan 2022 10:49:42 +0800 Subject: [PATCH] add DynamicInferModel tag for flclient --- .../java/com/mindspore/flclient/Common.java | 2 +- .../com/mindspore/flclient/FLLiteClient.java | 13 ++++- .../com/mindspore/flclient/FLParameter.java | 9 +++ .../mindspore/flclient/LocalFLParameter.java | 12 +++- .../com/mindspore/flclient/StartFLJob.java | 15 +---- .../com/mindspore/flclient/SyncFLJob.java | 50 ++++++++++++++--- .../mindspore/flclient/model/Callback.java | 2 +- .../com/mindspore/flclient/model/Client.java | 55 +++++++++---------- .../client/run_client_x86.py | 9 +++ 9 files changed, 110 insertions(+), 57 deletions(-) diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java index d2fd875055d..77d26dbb89f 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/Common.java @@ -519,7 +519,7 @@ public class Common { LOGGER.info(Common.addTag("==========Loading model, " + modelPath + " Create " + " Session=============")); Client client = ClientManager.getClient(flParameter.getFlName()); - tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig()); + tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig(), flParameter.getInputShape()); if (!Status.SUCCESS.equals(tag)) { LOGGER.severe(Common.addTag("[initSession] unsolved error code in : the return " + "is -1")); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java index 574c8654648..a36556422e6 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java @@ -315,7 +315,14 @@ public class FLLiteClient { } retCode = ResponseCode.SUCCEED; LOGGER.info(Common.addTag("[train] train in " + flParameter.getFlName())); - Status tag = client.trainModel(epochs); + LOGGER.info(Common.addTag("[train] lr for client is: " + localFLParameter.getLr())); + Status tag = client.setLearningRate(localFLParameter.getLr()); + if (!Status.SUCCESS.equals(tag)) { + LOGGER.severe(Common.addTag("[train] setLearningRate failed, return -1, please check")); + retCode = ResponseCode.RequestError; + return FLClientStatus.FAILED; + } + tag = client.trainModel(epochs); if (!Status.SUCCESS.equals(tag)) { failed("[train] unsolved error code in ", ResponseCode.RequestError); } @@ -568,12 +575,12 @@ public class FLLiteClient { float acc = 0; if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); - client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig()); + client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape()); LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath())); acc = client.evalModel(); } else { LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); - client.initSessionAndInputs(flParameter.getTrainModelPath(), 
localFLParameter.getMsConfig()); + client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape()); LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getTrainModelPath())); acc = client.evalModel(); } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java index 8e765226224..828c3333751 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLParameter.java @@ -83,6 +83,7 @@ public class FLParameter { private Map> dataMap = new HashMap<>(); private ServerMod serverMod; private int batchSize; + private int[][] inputShape; private FLParameter() { clientID = UUID.randomUUID().toString(); @@ -575,4 +576,12 @@ public class FLParameter { } this.batchSize = batchSize; } + + public int[][] getInputShape() { + return inputShape; + } + + public void setInputShape(int[][] inputShape) { + this.inputShape = inputShape; + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java index 59a6cbc4d3b..2234f5f882a 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/LocalFLParameter.java @@ -82,6 +82,8 @@ public class LocalFLParameter { private boolean stopJobFlag = false; private MSConfig msConfig = new MSConfig(); private boolean useSSL = true; + private float lr = 0.1f; + private LocalFLParameter() { // set classifierWeightName albertWeightName @@ -232,7 +234,6 @@ public class LocalFLParameter { msConfig.init(DeviceType, threadNum, cpuBindMode, enable_fp16); } - public boolean isUseSSL() { return useSSL; } @@ -240,4 +241,13 @@ public class LocalFLParameter { public void setUseSSL(boolean useSSL) { this.useSSL = useSSL; } + + public float getLr() { + LOGGER.severe(Common.addTag("[localFLParameter] the parameter of from server is: " + lr)); + return lr; + } + + public void setLr(float lr) { + this.lr = lr; + } } diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java index 47953fb9375..8697e5175ac 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/StartFLJob.java @@ -305,13 +305,6 @@ public class StartFLJob { retCode = ResponseCode.RequestError; return status; } - LOGGER.info(Common.addTag("[startFLJob] set for client: " + lr)); - tag = client.setLearningRate(lr); - if (!Status.SUCCESS.equals(tag)) { - LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check")); - retCode = ResponseCode.RequestError; - return FLClientStatus.FAILED; - } LOGGER.info(Common.addTag("[startFLJob] set for client: " + batchSize)); client.setBatchSize(batchSize); tag = client.updateFeatures(flParameter.getTrainModelPath(), trainFeatureMaps); @@ -351,13 +344,6 @@ public class StartFLJob { retCode = ResponseCode.RequestError; return status; } - LOGGER.info(Common.addTag("[startFLJob] set for client: " + lr)); - tag = 
client.setLearningRate(lr); - if (!Status.SUCCESS.equals(tag)) { - LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check")); - retCode = ResponseCode.RequestError; - return FLClientStatus.FAILED; - } LOGGER.info(Common.addTag("[startFLJob] set for client: " + batchSize)); client.setBatchSize(batchSize); tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps); @@ -452,6 +438,7 @@ public class StartFLJob { "valid, " + "will use the default value 0.1")); } + localFLParameter.setLr(lr); batchSize = flPlanConfig.miniBatch(); if (Common.checkFLName(flParameter.getFlName())) { status = deprecatedParseFeatures(flJob); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java index 6cdc305ca7e..e8024c0dac5 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java @@ -35,6 +35,7 @@ import mindspore.schema.ResponseGetModel; import java.nio.ByteBuffer; import java.security.SecureRandom; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -278,32 +279,47 @@ public class SyncFLJob { */ public int[] modelInference() { if (Common.checkFLName(flParameter.getFlName())) { + LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED)); return deprecatedModelInference(); } + return new int[0]; + } + /** + * Starts an inference task on the device. + * + * @return the status code corresponding to the response message. + */ + public List modelInfer() { Client client = ClientManager.getClient(flParameter.getFlName()); localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false); localFLParameter.setStopJobFlag(false); - int[] labels = new int[0]; + if (!(null == flParameter.getInputShape())) { + LOGGER.info("[model inference] the inference model has dynamic input."); + } Map dataSize = client.initDataSets(flParameter.getDataMap()); if (dataSize.isEmpty()) { LOGGER.severe("[model inference] initDataSets failed, please check"); - return new int[0]; + client.free(); + return null; } - Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig()); + Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape()); if (!Status.SUCCESS.equals(tag)) { LOGGER.severe(Common.addTag("[model inference] unsolved error code in : the return " + " status is: " + tag)); - return new int[0]; + client.free(); + return null; } client.setBatchSize(flParameter.getBatchSize()); LOGGER.info(Common.addTag("===========model inference=============")); - labels = client.inferModel().stream().mapToInt(Integer::valueOf).toArray(); - if (labels == null || labels.length == 0) { + List labels = client.inferModel(); + if (labels == null || labels.size() == 0) { LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please " + "check"); + client.free(); + return null; } - LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); + LOGGER.info(Common.addTag("[model inference] the predicted outputs: " + Arrays.deepToString(labels.toArray()))); client.free(); LOGGER.info(Common.addTag("[model inference] inference finish")); return labels; @@ -416,6 +432,18 @@ public class 
SyncFLJob { } } + private static int[][] getInputShapeArray(String inputShape) { + String[] inputs = inputShape.split(";"); + int inputsSize = inputs.length; + int[][] inputsArray = new int[inputsSize][]; + for (int i = 0; i < inputsSize; i++) { + String[] input = inputs[i].split(","); + int[] inputArray = Arrays.stream(input).mapToInt(Integer::parseInt).toArray(); + inputsArray[i] = inputArray; + } + return inputsArray; + } + private static void task(String[] args) { String trainDataPath = args[0]; String evalDataPath = args[1]; @@ -438,10 +466,14 @@ public class SyncFLJob { String inferWeightName = args[17]; String nameRegex = args[18]; String serverMod = args[19]; + String inputShape = args[21]; int batchSize = Integer.parseInt(args[20]); - FLParameter flParameter = FLParameter.getInstance(); + if (!("null".equals(inputShape) || inputShape == null)) { + flParameter.setInputShape(getInputShapeArray(inputShape)); + } + // create dataset of map Map> dataMap = createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex); @@ -476,7 +508,7 @@ public class SyncFLJob { flParameter.setThreadNum(threadNum); flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode)); flParameter.setBatchSize(batchSize); - syncFLJob.modelInference(); + syncFLJob.modelInfer(); break; case "getModel": LOGGER.info(Common.addTag("start syncFLJob.getModel()")); diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Callback.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Callback.java index f9daf0ad5a7..bde941b39be 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Callback.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Callback.java @@ -32,7 +32,7 @@ import java.util.logging.Logger; public abstract class Callback { private static final Logger logger = Logger.getLogger(LossCallback.class.toString()); - LiteSession session; + protected LiteSession session; public int steps = 0; diff --git a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Client.java b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Client.java index adf7db3ee0c..593481d037e 100644 --- a/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Client.java +++ b/mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/Client.java @@ -28,6 +28,7 @@ import mindspore.schema.FeatureMap; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -53,8 +54,6 @@ public abstract class Client { */ public Map dataSets = new HashMap<>(); - private boolean isDynamicInferModel = false; - private final List inputsBuffer = new ArrayList<>(); /** @@ -88,7 +87,7 @@ public abstract class Client { * @param inferCallback callback used for infer model. * @return infer result. */ - public abstract List getInferResult(List inferCallback); + public abstract List getInferResult(List inferCallback); /** * Init lite session and inputs buffer. @@ -97,26 +96,37 @@ public abstract class Client { * @param config session run config. * @return execute status. 
*/ - public Status initSessionAndInputs(String modelPath, MSConfig config) { + public Status initSessionAndInputs(String modelPath, MSConfig config, int[][] inputShapes) { if (modelPath == null) { logger.severe(Common.addTag("session init failed")); - return Status.NULLPTR; + return Status.FAILED; } - if (isDynamicInferModel) { - logger.info(Common.addTag(modelPath + " is dynamic input")); - } - Optional optTrainSession = initSession(modelPath, config); + Optional optTrainSession = initSession(modelPath, config, inputShapes != null); if (!optTrainSession.isPresent()) { logger.severe(Common.addTag("session init failed")); - return Status.NULLPTR; + return Status.FAILED; } trainSession = optTrainSession.get(); inputsBuffer.clear(); - List inputs = trainSession.getInputs(); - for (MSTensor input : inputs) { - ByteBuffer inputBuffer = ByteBuffer.allocateDirect((int) input.size()); - inputBuffer.order(ByteOrder.nativeOrder()); - inputsBuffer.add(inputBuffer); + if (inputShapes == null) { + List inputs = trainSession.getInputs(); + for (MSTensor input : inputs) { + ByteBuffer inputBuffer = ByteBuffer.allocateDirect((int) input.size()); + inputBuffer.order(ByteOrder.nativeOrder()); + inputsBuffer.add(inputBuffer); + } + } else { + boolean isSuccess = trainSession.resize(trainSession.getInputs(), inputShapes); + if (!isSuccess) { + logger.severe(Common.addTag("session resize failed")); + return Status.FAILED; + } + for (int[] shapes : inputShapes) { + int size = IntStream.of(shapes).reduce((a, b) -> a * b).getAsInt() * Integer.BYTES; + ByteBuffer inputBuffer = ByteBuffer.allocateDirect(size); + inputBuffer.order(ByteOrder.nativeOrder()); + inputsBuffer.add(inputBuffer); + } } return Status.SUCCESS; } @@ -188,7 +198,7 @@ public abstract class Client { * * @return infer status. */ - public List inferModel() { + public List inferModel() { boolean isSuccess = trainSession.eval(); if (!isSuccess) { logger.severe(Common.addTag("train session switch eval mode failed")); @@ -252,7 +262,7 @@ public abstract class Client { return Status.SUCCESS; } - private Optional initSession(String modelPath, MSConfig msConfig) { + private Optional initSession(String modelPath, MSConfig msConfig, boolean isDynamicInferModel) { if (modelPath == null) { logger.severe(Common.addTag("modelPath cannot be empty")); return Optional.empty(); @@ -379,15 +389,4 @@ public abstract class Client { dataset.batchSize = batchSize; } } - - /** - * Resize client input shape dims. - * - * @param dims new input shapes. - * @return resize status. - */ - public boolean resize(int[][] dims) { - isDynamicInferModel = true; - return trainSession.resize(trainSession.getInputs(), dims); - } } \ No newline at end of file diff --git a/tests/st/fl/cross_device_lenet/client/run_client_x86.py b/tests/st/fl/cross_device_lenet/client/run_client_x86.py index 4e778061199..c023e3893da 100644 --- a/tests/st/fl/cross_device_lenet/client/run_client_x86.py +++ b/tests/st/fl/cross_device_lenet/client/run_client_x86.py @@ -116,6 +116,12 @@ def get_parser(): - client_num Specifies the number of clients. The value must be the same as that of 'start_fl_job_cnt' when the server is started. This parameter is not required in actual scenarios. + - input_shape + string type, this parameter is used to set the real input shape of the model, where the value corresponding to + the batch size in the inputs shape should be consistent with the parameter 'batch_size'. 
When the model contains
+      only one input, separate the sizes of its dimensions with a comma ',', such as "32,64,96"; when the model
+      contains multiple inputs, separate the shapes of the different inputs with a semicolon ';', such as
+      "32,64,96;12,24,26". Note that this parameter is only required when the inference model supports dynamic inputs.
     """
     parser = argparse.ArgumentParser(description="Run SyncFLJob.java case")
@@ -150,6 +156,7 @@ def get_parser():
     parser.add_argument("--name_regex", type=str, default=",")
     parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
     parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--input_shape", type=str, default="null")
     parser.add_argument("--client_num", type=int, default=0)
     return parser

@@ -188,6 +195,7 @@ infer_weight_name = args.infer_weight_name
 name_regex = args.name_regex
 server_mode = args.server_mode
 batch_size = args.batch_size
+input_shape = args.input_shape
 client_num = args.client_num

@@ -274,6 +282,7 @@ for i in range(client_num):
     cmd_client += name_regex + " "
     cmd_client += server_mode + " "
     cmd_client += str(batch_size) + " "
+    cmd_client += input_shape + " "
     cmd_client += " > client" + ".log 2>&1 &"
     print(cmd_client)
     subprocess.call(['bash', '-c', cmd_client])
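
Illustrative sketch (not part of the patch): the snippet below mirrors how the --input_shape string passed by run_client_x86.py is parsed into the int[][] handed to FLParameter.setInputShape (see SyncFLJob.getInputShapeArray above), and how each parsed shape maps to the direct-buffer size allocated in the resize branch of Client.initSessionAndInputs. The class and method names here are chosen for the example only.

import java.util.Arrays;
import java.util.stream.IntStream;

// Standalone sketch mirroring SyncFLJob.getInputShapeArray and the buffer sizing in
// Client.initSessionAndInputs; it is not the code added by this patch.
public class InputShapeDemo {
    // ';' separates the shapes of different inputs, ',' separates the dimensions of one input.
    static int[][] parseInputShape(String inputShape) {
        String[] inputs = inputShape.split(";");
        int[][] shapes = new int[inputs.length][];
        for (int i = 0; i < inputs.length; i++) {
            shapes[i] = Arrays.stream(inputs[i].split(",")).mapToInt(Integer::parseInt).toArray();
        }
        return shapes;
    }

    public static void main(String[] args) {
        int[][] shapes = parseInputShape("32,64,96;12,24,26");
        for (int[] shape : shapes) {
            // Product of the dimensions times Integer.BYTES (4), as in the resize branch of
            // initSessionAndInputs, which pre-allocates one direct ByteBuffer per input.
            int bytes = IntStream.of(shape).reduce((a, b) -> a * b).getAsInt() * Integer.BYTES;
            System.out.println(Arrays.toString(shape) + " -> " + bytes + " bytes");
        }
    }
}

For example, the first shape "32,64,96" yields 32 * 64 * 96 * 4 = 786432 bytes for its input buffer.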
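A hedged sketch of the caller-side change introduced by this patch: configuring a dynamic-shape inference model and using the new SyncFLJob.modelInfer() entry point instead of the deprecated int[] modelInference(). Only FLParameter.getInstance, setBatchSize, setInputShape and modelInfer appear in the patch; the other setters, the model name, the model path, and the List<Object> element type are assumptions made for illustration (the generic parameters are not visible in this patch text).

import java.util.List;

import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.SyncFLJob;

public class InferDemo {
    public static void main(String[] args) {
        FLParameter flParameter = FLParameter.getInstance();
        // Assumed setters, counterparts of the getFlName()/getInferModelPath() calls used in the patch.
        flParameter.setFlName("com.mindspore.flclient.demo.albert.AlbertClient"); // hypothetical Client implementation
        flParameter.setInferModelPath("/path/to/albert_inference.ms");            // hypothetical model path
        flParameter.setBatchSize(32);
        // Only needed for models with dynamic inputs; the batch dimension must match setBatchSize.
        flParameter.setInputShape(new int[][]{{32, 64, 96}, {12, 24, 26}});
        // The inference data must also be registered before calling modelInfer(); the setter name
        // setDataMap is assumed here from the getDataMap() call in the patch.
        // flParameter.setDataMap(dataMap);

        SyncFLJob syncFLJob = new SyncFLJob();
        // Replaces the now-deprecated int[] modelInference(); returns null when init or inference fails.
        List<Object> outputs = syncFLJob.modelInfer();
        System.out.println(outputs == null ? "inference failed, see client log" : "outputs: " + outputs.size());
    }
}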