add DynamicInferModel tag for flclient in r1.6

This commit is contained in:
zhoushan 2022-01-26 10:12:00 +08:00
parent 76bca2addc
commit 15ec4f5af2
9 changed files with 144 additions and 47 deletions

View File

@ -519,7 +519,7 @@ public class Common {
LOGGER.info(Common.addTag("==========Loading model, " + modelPath + " Create " +
" Session============="));
Client client = ClientManager.getClient(flParameter.getFlName());
tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig());
tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig(), flParameter.getInputShape());
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));

View File

@ -315,7 +315,14 @@ public class FLLiteClient {
}
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("[train] train in " + flParameter.getFlName()));
Status tag = client.trainModel(epochs);
LOGGER.info(Common.addTag("[train] lr for client is: " + localFLParameter.getLr()));
Status tag = client.setLearningRate(localFLParameter.getLr());
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[train] setLearningRate failed, return -1, please check"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
tag = client.trainModel(epochs);
if (!Status.SUCCESS.equals(tag)) {
failed("[train] unsolved error code in <client.trainModel>", ResponseCode.RequestError);
}
@ -568,12 +575,12 @@ public class FLLiteClient {
float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath()));
acc = client.evalModel();
} else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig());
client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getTrainModelPath()));
acc = client.evalModel();
}

View File

@ -83,6 +83,7 @@ public class FLParameter {
private Map<RunType, List<String>> dataMap = new HashMap<>();
private ServerMod serverMod;
private int batchSize;
private int[][] inputShape;
private FLParameter() {
clientID = UUID.randomUUID().toString();
@ -575,4 +576,12 @@ public class FLParameter {
}
this.batchSize = batchSize;
}
// Returns the configured per-input shapes for dynamic-input models;
// null when the model uses its fixed, compiled-in input shapes.
public int[][] getInputShape() {
return inputShape;
}
// Sets the real input shapes, one int[] of dimension sizes per model input.
// Only needed when the inference model supports dynamic inputs; parsed from the
// "a,b,c;d,e,f" command-line string — see SyncFLJob.getInputShapeArray.
public void setInputShape(int[][] inputShape) {
this.inputShape = inputShape;
}
}

View File

@ -82,6 +82,8 @@ public class LocalFLParameter {
private boolean stopJobFlag = false;
private MSConfig msConfig = new MSConfig();
private boolean useSSL = true;
private float lr = 0.1f;
private LocalFLParameter() {
// set classifierWeightName and albertWeightName
@ -232,7 +234,6 @@ public class LocalFLParameter {
msConfig.init(DeviceType, threadNum, cpuBindMode, enable_fp16);
}
public boolean isUseSSL() {
return useSSL;
}
@ -240,4 +241,13 @@ public class LocalFLParameter {
public void setUseSSL(boolean useSSL) {
this.useSSL = useSSL;
}
// Returns the learning rate delivered by the server (field default: 0.1f).
public float getLr() {
    // Routine trace, not an error: log at INFO instead of SEVERE (the original
    // SEVERE level would pollute error logs on every read of this value).
    LOGGER.info(Common.addTag("[localFLParameter] the parameter of <lr> from server is: " + lr));
    return lr;
}
// Stores the learning rate received from the server's startFLJob response
// (StartFLJob calls localFLParameter.setLr(lr)); FLLiteClient reads it back
// via getLr() before training starts.
public void setLr(float lr) {
this.lr = lr;
}
}

View File

@ -305,13 +305,6 @@ public class StartFLJob {
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
tag = client.setLearningRate(lr);
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
client.setBatchSize(batchSize);
tag = client.updateFeatures(flParameter.getTrainModelPath(), trainFeatureMaps);
@ -351,13 +344,6 @@ public class StartFLJob {
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
tag = client.setLearningRate(lr);
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
client.setBatchSize(batchSize);
tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps);
@ -452,6 +438,7 @@ public class StartFLJob {
"valid, " +
"will use the default value 0.1"));
}
localFLParameter.setLr(lr);
batchSize = flPlanConfig.miniBatch();
if (Common.checkFLName(flParameter.getFlName())) {
status = deprecatedParseFeatures(flJob);

View File

@ -35,6 +35,7 @@ import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@ -278,32 +279,47 @@ public class SyncFLJob {
*/
/**
 * Starts an inference task on the device via the deprecated legacy path.
 *
 * @return the predicted labels from the legacy inference, or an empty array
 *         when the flName is not one of the legacy model names.
 */
public int[] modelInference() {
    boolean isLegacyModel = Common.checkFLName(flParameter.getFlName());
    if (!isLegacyModel) {
        return new int[0];
    }
    LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
    return deprecatedModelInference();
}
/**
* Starts an inference task on the device.
*
* @return the status code corresponding to the response message.
*/
public List<Object> modelInfer() {
Client client = ClientManager.getClient(flParameter.getFlName());
localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false);
localFLParameter.setStopJobFlag(false);
int[] labels = new int[0];
if (!(null == flParameter.getInputShape())) {
LOGGER.info("[model inference] the inference model has dynamic input.");
}
Map<RunType, Integer> dataSize = client.initDataSets(flParameter.getDataMap());
if (dataSize.isEmpty()) {
LOGGER.severe("[model inference] initDataSets failed, please check");
return new int[0];
client.free();
return null;
}
Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
if (!Status.SUCCESS.equals(tag)) {
LOGGER.severe(Common.addTag("[model inference] unsolved error code in <initSessionAndInputs>: the return " +
" status is: " + tag));
return new int[0];
client.free();
return null;
}
client.setBatchSize(flParameter.getBatchSize());
LOGGER.info(Common.addTag("===========model inference============="));
labels = client.inferModel().stream().mapToInt(Integer::valueOf).toArray();
if (labels == null || labels.length == 0) {
List<Object> labels = client.inferModel();
if (labels == null || labels.size() == 0) {
LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please " +
"check");
client.free();
return null;
}
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
LOGGER.info(Common.addTag("[model inference] the predicted outputs: " + Arrays.deepToString(labels.toArray())));
client.free();
LOGGER.info(Common.addTag("[model inference] inference finish"));
return labels;
@ -416,6 +432,18 @@ public class SyncFLJob {
}
}
/**
 * Parses the command-line "input_shape" string into per-input shape arrays.
 *
 * <p>Format: the dimension sizes of one input are joined with ',', and multiple
 * inputs are separated by ';' — e.g. "32,64,96;12,24,26" yields
 * {{32,64,96},{12,24,26}}.
 *
 * @param inputShape the raw shape string; must be non-null and well-formed.
 * @return one int[] of dimension sizes per model input.
 * @throws NumberFormatException if any dimension token is not a valid integer.
 */
private static int[][] getInputShapeArray(String inputShape) {
    String[] inputs = inputShape.split(";");
    int[][] inputsArray = new int[inputs.length][];
    for (int i = 0; i < inputs.length; i++) {
        // trim() tolerates stray spaces around dimension values, e.g. "32, 64";
        // previously such input would throw NumberFormatException.
        inputsArray[i] = Arrays.stream(inputs[i].split(","))
                .map(String::trim)
                .mapToInt(Integer::parseInt)
                .toArray();
    }
    return inputsArray;
}
private static void task(String[] args) {
String trainDataPath = args[0];
String evalDataPath = args[1];
@ -438,10 +466,14 @@ public class SyncFLJob {
String inferWeightName = args[17];
String nameRegex = args[18];
String serverMod = args[19];
String inputShape = args[21];
int batchSize = Integer.parseInt(args[20]);
FLParameter flParameter = FLParameter.getInstance();
if (!("null".equals(inputShape) || inputShape == null)) {
flParameter.setInputShape(getInputShapeArray(inputShape));
}
// create dataset of map
Map<RunType, List<String>> dataMap = createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex);
@ -476,7 +508,7 @@ public class SyncFLJob {
flParameter.setThreadNum(threadNum);
flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
flParameter.setBatchSize(batchSize);
syncFLJob.modelInference();
syncFLJob.modelInfer();
break;
case "getModel":
LOGGER.info(Common.addTag("start syncFLJob.getModel()"));

View File

@ -32,7 +32,7 @@ import java.util.logging.Logger;
public abstract class Callback {
private static final Logger logger = Logger.getLogger(LossCallback.class.toString());
LiteSession session;
protected LiteSession session;
public int steps = 0;

View File

@ -17,10 +17,10 @@
package com.mindspore.flclient.model;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.Model;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;
import mindspore.schema.FeatureMap;
@ -28,11 +28,13 @@ import mindspore.schema.FeatureMap;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.stream.IntStream;
/**
* Defining the client base class.
@ -85,7 +87,7 @@ public abstract class Client {
* @param inferCallback callback used for infer model.
* @return infer result.
*/
public abstract List<Integer> getInferResult(List<Callback> inferCallback);
public abstract List<Object> getInferResult(List<Callback> inferCallback);
/**
* Init lite session and inputs buffer.
@ -94,23 +96,37 @@ public abstract class Client {
* @param config session run config.
* @return execute status.
*/
public Status initSessionAndInputs(String modelPath, MSConfig config) {
public Status initSessionAndInputs(String modelPath, MSConfig config, int[][] inputShapes) {
if (modelPath == null) {
logger.severe(Common.addTag("session init failed"));
return Status.NULLPTR;
return Status.FAILED;
}
Optional<LiteSession> optTrainSession = initSession(modelPath, config);
Optional<LiteSession> optTrainSession = initSession(modelPath, config, inputShapes != null);
if (!optTrainSession.isPresent()) {
logger.severe(Common.addTag("session init failed"));
return Status.NULLPTR;
return Status.FAILED;
}
trainSession = optTrainSession.get();
inputsBuffer.clear();
List<MSTensor> inputs = trainSession.getInputs();
for (MSTensor input : inputs) {
ByteBuffer inputBuffer = ByteBuffer.allocateDirect((int) input.size());
inputBuffer.order(ByteOrder.nativeOrder());
inputsBuffer.add(inputBuffer);
if (inputShapes == null) {
List<MSTensor> inputs = trainSession.getInputs();
for (MSTensor input : inputs) {
ByteBuffer inputBuffer = ByteBuffer.allocateDirect((int) input.size());
inputBuffer.order(ByteOrder.nativeOrder());
inputsBuffer.add(inputBuffer);
}
} else {
boolean isSuccess = trainSession.resize(trainSession.getInputs(), inputShapes);
if (!isSuccess) {
logger.severe(Common.addTag("session resize failed"));
return Status.FAILED;
}
for (int[] shapes : inputShapes) {
int size = IntStream.of(shapes).reduce((a, b) -> a * b).getAsInt() * Integer.BYTES;
ByteBuffer inputBuffer = ByteBuffer.allocateDirect(size);
inputBuffer.order(ByteOrder.nativeOrder());
inputsBuffer.add(inputBuffer);
}
}
return Status.SUCCESS;
}
@ -182,7 +198,7 @@ public abstract class Client {
*
* @return infer status.
*/
public List<Integer> inferModel() {
public List<Object> inferModel() {
boolean isSuccess = trainSession.eval();
if (!isSuccess) {
logger.severe(Common.addTag("train session switch eval mode failed"));
@ -246,17 +262,44 @@ public abstract class Client {
return Status.SUCCESS;
}
private Optional<LiteSession> initSession(String modelPath, MSConfig msConfig) {
private Optional<LiteSession> initSession(String modelPath, MSConfig msConfig, boolean isDynamicInferModel) {
if (modelPath == null) {
logger.severe(Common.addTag("modelPath cannot be empty"));
return Optional.empty();
}
LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
return Optional.empty();
// only LiteSession supports dynamic input shapes
if (isDynamicInferModel) {
Model model = new Model();
boolean isSuccess = model.loadModel(modelPath);
if (!isSuccess) {
logger.severe(Common.addTag("load model failed:" + modelPath));
return Optional.empty();
}
trainSession = LiteSession.createSession(msConfig);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
msConfig.free();
model.free();
return Optional.empty();
}
msConfig.free();
isSuccess = trainSession.compileGraph(model);
if (!isSuccess) {
logger.severe(Common.addTag("compile graph failed:" + modelPath));
model.free();
trainSession.free();
return Optional.empty();
}
model.free();
return Optional.of(trainSession);
} else {
trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
return Optional.empty();
}
return Optional.of(trainSession);
}
return Optional.of(trainSession);
}
/**

View File

@ -116,6 +116,12 @@ def get_parser():
- client_num
Specifies the number of clients. The value must be the same as that of 'start_fl_job_cnt' when the server is
started. This parameter is not required in actual scenarios.
- input_shape
string type, this parameter sets the real input shape of the model; the value corresponding to
the batch size in the input shape must be consistent with the parameter 'batch_size'. When the model contains
only one input, join the sizes of its dimensions with a comma ',', such as "32,64,96"; when the model
contains multiple inputs, separate the shapes of the different inputs with a semicolon ';', such as
"32,64,96;12,24,26". Note that this parameter is only required when the inference model supports dynamic inputs.
"""
parser = argparse.ArgumentParser(description="Run SyncFLJob.java case")
@ -150,6 +156,7 @@ def get_parser():
parser.add_argument("--name_regex", type=str, default=",")
parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--input_shape", type=str, default="null")
parser.add_argument("--client_num", type=int, default=0)
return parser
@ -188,6 +195,7 @@ infer_weight_name = args.infer_weight_name
name_regex = args.name_regex
server_mode = args.server_mode
batch_size = args.batch_size
input_shape = args.input_shape
client_num = args.client_num
@ -274,6 +282,7 @@ for i in range(client_num):
cmd_client += name_regex + " "
cmd_client += server_mode + " "
cmd_client += str(batch_size) + " "
cmd_client += input_shape + " "
cmd_client += " > client" + ".log 2>&1 &"
print(cmd_client)
subprocess.call(['bash', '-c', cmd_client])