add DynamicInferModel tag for flclient in r1.6
commit 15ec4f5af2
parent 76bca2addc
@@ -519,7 +519,7 @@ public class Common {
         LOGGER.info(Common.addTag("==========Loading model, " + modelPath + " Create " +
                 " Session============="));
         Client client = ClientManager.getClient(flParameter.getFlName());
-        tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig());
+        tag = client.initSessionAndInputs(modelPath, localFLParameter.getMsConfig(), flParameter.getInputShape());
         if (!Status.SUCCESS.equals(tag)) {
             LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
                     "is -1"));
@@ -315,7 +315,14 @@ public class FLLiteClient {
         }
         retCode = ResponseCode.SUCCEED;
         LOGGER.info(Common.addTag("[train] train in " + flParameter.getFlName()));
-        Status tag = client.trainModel(epochs);
+        LOGGER.info(Common.addTag("[train] lr for client is: " + localFLParameter.getLr()));
+        Status tag = client.setLearningRate(localFLParameter.getLr());
+        if (!Status.SUCCESS.equals(tag)) {
+            LOGGER.severe(Common.addTag("[train] setLearningRate failed, return -1, please check"));
+            retCode = ResponseCode.RequestError;
+            return FLClientStatus.FAILED;
+        }
+        tag = client.trainModel(epochs);
         if (!Status.SUCCESS.equals(tag)) {
             failed("[train] unsolved error code in <client.trainModel>", ResponseCode.RequestError);
         }
@@ -568,12 +575,12 @@ public class FLLiteClient {
         float acc = 0;
         if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
             LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
-            client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
+            client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
             LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath()));
             acc = client.evalModel();
         } else {
             LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
-            client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig());
+            client.initSessionAndInputs(flParameter.getTrainModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
             LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getTrainModelPath()));
             acc = client.evalModel();
         }
@@ -83,6 +83,7 @@ public class FLParameter {
     private Map<RunType, List<String>> dataMap = new HashMap<>();
     private ServerMod serverMod;
     private int batchSize;
+    private int[][] inputShape;
 
     private FLParameter() {
         clientID = UUID.randomUUID().toString();
@@ -575,4 +576,12 @@ public class FLParameter {
         }
         this.batchSize = batchSize;
     }
+
+    public int[][] getInputShape() {
+        return inputShape;
+    }
+
+    public void setInputShape(int[][] inputShape) {
+        this.inputShape = inputShape;
+    }
 }
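For context, a caller-side sketch of the new FLParameter API (getInstance, setBatchSize and setInputShape all appear in this diff; the concrete shape values are the sample values from the run script's docstring, not defaults):

    // Hypothetical configuration of a dynamic-input inference model with two inputs.
    FLParameter flParameter = FLParameter.getInstance();
    flParameter.setBatchSize(32);
    flParameter.setInputShape(new int[][]{
            {32, 64, 96},   // first input; the batch dimension should match setBatchSize
            {12, 24, 26}    // second input
    });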
@@ -82,6 +82,8 @@ public class LocalFLParameter {
     private boolean stopJobFlag = false;
     private MSConfig msConfig = new MSConfig();
     private boolean useSSL = true;
+    private float lr = 0.1f;
+
 
     private LocalFLParameter() {
         // set classifierWeightName albertWeightName
@@ -232,7 +234,6 @@ public class LocalFLParameter {
         msConfig.init(DeviceType, threadNum, cpuBindMode, enable_fp16);
     }
 
-
     public boolean isUseSSL() {
         return useSSL;
     }
@@ -240,4 +241,13 @@ public class LocalFLParameter {
     public void setUseSSL(boolean useSSL) {
         this.useSSL = useSSL;
     }
+
+    public float getLr() {
+        LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <lr> from server is: " + lr));
+        return lr;
+    }
+
+    public void setLr(float lr) {
+        this.lr = lr;
+    }
 }
@@ -305,13 +305,6 @@ public class StartFLJob {
             retCode = ResponseCode.RequestError;
             return status;
         }
-        LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
-        tag = client.setLearningRate(lr);
-        if (!Status.SUCCESS.equals(tag)) {
-            LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
-            retCode = ResponseCode.RequestError;
-            return FLClientStatus.FAILED;
-        }
         LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
         client.setBatchSize(batchSize);
         tag = client.updateFeatures(flParameter.getTrainModelPath(), trainFeatureMaps);
@@ -351,13 +344,6 @@ public class StartFLJob {
             retCode = ResponseCode.RequestError;
             return status;
         }
-        LOGGER.info(Common.addTag("[startFLJob] set <lr> for client: " + lr));
-        tag = client.setLearningRate(lr);
-        if (!Status.SUCCESS.equals(tag)) {
-            LOGGER.severe(Common.addTag("[startFLJob] setLearningRate failed, return -1, please check"));
-            retCode = ResponseCode.RequestError;
-            return FLClientStatus.FAILED;
-        }
         LOGGER.info(Common.addTag("[startFLJob] set <batch size> for client: " + batchSize));
         client.setBatchSize(batchSize);
         tag = client.updateFeatures(flParameter.getTrainModelPath(), featureMaps);
@@ -452,6 +438,7 @@ public class StartFLJob {
                     "valid, " +
                     "will use the default value 0.1"));
         }
+        localFLParameter.setLr(lr);
         batchSize = flPlanConfig.miniBatch();
         if (Common.checkFLName(flParameter.getFlName())) {
             status = deprecatedParseFeatures(flJob);
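Together with the FLLiteClient hunk above, this moves learning-rate handling out of StartFLJob: the job now only stores the server-provided lr, and the train step reads it back and applies it. A sketch of the new flow, using only calls that appear in this diff:

    // StartFLJob (this hunk): persist the lr parsed from the server's plan; 0.1 is the documented fallback.
    localFLParameter.setLr(lr);

    // FLLiteClient.train (hunk above): read it back and apply it before every training round.
    Status tag = client.setLearningRate(localFLParameter.getLr());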
@@ -35,6 +35,7 @@ import mindspore.schema.ResponseGetModel;
 
 import java.nio.ByteBuffer;
 import java.security.SecureRandom;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -278,32 +279,47 @@ public class SyncFLJob {
      */
     public int[] modelInference() {
         if (Common.checkFLName(flParameter.getFlName())) {
             LOGGER.warning(Common.addTag(Common.LOG_DEPRECATED));
             return deprecatedModelInference();
         }
+        return new int[0];
+    }
+
+    /**
+     * Starts an inference task on the device.
+     *
+     * @return the status code corresponding to the response message.
+     */
+    public List<Object> modelInfer() {
         Client client = ClientManager.getClient(flParameter.getFlName());
         localFLParameter.setMsConfig(0, flParameter.getThreadNum(), flParameter.getCpuBindMode(), false);
         localFLParameter.setStopJobFlag(false);
-        int[] labels = new int[0];
+        if (!(null == flParameter.getInputShape())) {
+            LOGGER.info("[model inference] the inference model has dynamic input.");
+        }
         Map<RunType, Integer> dataSize = client.initDataSets(flParameter.getDataMap());
         if (dataSize.isEmpty()) {
             LOGGER.severe("[model inference] initDataSets failed, please check");
-            return new int[0];
+            client.free();
+            return null;
         }
-        Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig());
+        Status tag = client.initSessionAndInputs(flParameter.getInferModelPath(), localFLParameter.getMsConfig(), flParameter.getInputShape());
         if (!Status.SUCCESS.equals(tag)) {
             LOGGER.severe(Common.addTag("[model inference] unsolved error code in <initSessionAndInputs>: the return " +
                     " status is: " + tag));
-            return new int[0];
+            client.free();
+            return null;
         }
+        client.setBatchSize(flParameter.getBatchSize());
         LOGGER.info(Common.addTag("===========model inference============="));
-        labels = client.inferModel().stream().mapToInt(Integer::valueOf).toArray();
-        if (labels == null || labels.length == 0) {
+        List<Object> labels = client.inferModel();
+        if (labels == null || labels.size() == 0) {
             LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please " +
                     "check");
             client.free();
             return null;
         }
-        LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
+        LOGGER.info(Common.addTag("[model inference] the predicted outputs: " + Arrays.deepToString(labels.toArray())));
         client.free();
         LOGGER.info(Common.addTag("[model inference] inference finish"));
         return labels;
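For orientation, a minimal driver sketch for the new entry point (modelInfer, FLParameter.getInstance, setBatchSize and setInputShape are from this diff; setFlName and setInferModelPath are assumed counterparts of the getters used here, and all values are placeholders):

    FLParameter flParameter = FLParameter.getInstance();
    flParameter.setFlName("com.example.CustomClient");      // placeholder client class name
    flParameter.setInferModelPath("/data/infer_model.ms");  // placeholder model path
    flParameter.setBatchSize(32);
    flParameter.setInputShape(new int[][]{{32, 64, 96}});   // only needed for dynamic-input models

    SyncFLJob syncFLJob = new SyncFLJob();
    List<Object> outputs = syncFLJob.modelInfer();          // returns null on any failure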
@@ -416,6 +432,18 @@ public class SyncFLJob {
         }
     }
 
+    private static int[][] getInputShapeArray(String inputShape) {
+        String[] inputs = inputShape.split(";");
+        int inputsSize = inputs.length;
+        int[][] inputsArray = new int[inputsSize][];
+        for (int i = 0; i < inputsSize; i++) {
+            String[] input = inputs[i].split(",");
+            int[] inputArray = Arrays.stream(input).mapToInt(Integer::parseInt).toArray();
+            inputsArray[i] = inputArray;
+        }
+        return inputsArray;
+    }
+
     private static void task(String[] args) {
         String trainDataPath = args[0];
         String evalDataPath = args[1];
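A worked example of the parsing rule implemented above, using the sample value from the run script's docstring:

    // "32,64,96;12,24,26" describes two inputs:
    //   inputsArray[0] == {32, 64, 96}
    //   inputsArray[1] == {12, 24, 26}
    int[][] inputsArray = getInputShapeArray("32,64,96;12,24,26");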
@@ -438,10 +466,14 @@ public class SyncFLJob {
         String inferWeightName = args[17];
         String nameRegex = args[18];
         String serverMod = args[19];
+        String inputShape = args[21];
         int batchSize = Integer.parseInt(args[20]);
 
         FLParameter flParameter = FLParameter.getInstance();
+        if (!("null".equals(inputShape) || inputShape == null)) {
+            flParameter.setInputShape(getInputShapeArray(inputShape));
+        }
 
         // create dataset of map
         Map<RunType, List<String>> dataMap = createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex);
 
@@ -476,7 +508,7 @@ public class SyncFLJob {
                 flParameter.setThreadNum(threadNum);
                 flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
                 flParameter.setBatchSize(batchSize);
-                syncFLJob.modelInference();
+                syncFLJob.modelInfer();
                 break;
             case "getModel":
                 LOGGER.info(Common.addTag("start syncFLJob.getModel()"));
@@ -32,7 +32,7 @@ import java.util.logging.Logger;
 public abstract class Callback {
     private static final Logger logger = Logger.getLogger(LossCallback.class.toString());
 
-    LiteSession session;
+    protected LiteSession session;
 
     public int steps = 0;
 
@@ -17,10 +17,10 @@
 package com.mindspore.flclient.model;
 
 import com.mindspore.flclient.Common;
-import com.mindspore.flclient.FLClientStatus;
 import com.mindspore.flclient.LocalFLParameter;
 import com.mindspore.lite.LiteSession;
 import com.mindspore.lite.MSTensor;
+import com.mindspore.lite.Model;
 import com.mindspore.lite.TrainSession;
 import com.mindspore.lite.config.MSConfig;
 import mindspore.schema.FeatureMap;
@@ -28,11 +28,13 @@ import mindspore.schema.FeatureMap;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.logging.Logger;
+import java.util.stream.IntStream;
 
 /**
  * Defining the client base class.
@@ -85,7 +87,7 @@ public abstract class Client {
      * @param inferCallback callback used for infer model.
      * @return infer result.
      */
-    public abstract List<Integer> getInferResult(List<Callback> inferCallback);
+    public abstract List<Object> getInferResult(List<Callback> inferCallback);
 
     /**
      * Init lite session and inputs buffer.
@@ -94,23 +96,37 @@ public abstract class Client {
      * @param config session run config.
      * @return execute status.
      */
-    public Status initSessionAndInputs(String modelPath, MSConfig config) {
+    public Status initSessionAndInputs(String modelPath, MSConfig config, int[][] inputShapes) {
         if (modelPath == null) {
             logger.severe(Common.addTag("session init failed"));
-            return Status.NULLPTR;
+            return Status.FAILED;
         }
-        Optional<LiteSession> optTrainSession = initSession(modelPath, config);
+        Optional<LiteSession> optTrainSession = initSession(modelPath, config, inputShapes != null);
         if (!optTrainSession.isPresent()) {
             logger.severe(Common.addTag("session init failed"));
-            return Status.NULLPTR;
+            return Status.FAILED;
         }
         trainSession = optTrainSession.get();
         inputsBuffer.clear();
-        List<MSTensor> inputs = trainSession.getInputs();
-        for (MSTensor input : inputs) {
-            ByteBuffer inputBuffer = ByteBuffer.allocateDirect((int) input.size());
-            inputBuffer.order(ByteOrder.nativeOrder());
-            inputsBuffer.add(inputBuffer);
+        if (inputShapes == null) {
+            List<MSTensor> inputs = trainSession.getInputs();
+            for (MSTensor input : inputs) {
+                ByteBuffer inputBuffer = ByteBuffer.allocateDirect((int) input.size());
+                inputBuffer.order(ByteOrder.nativeOrder());
+                inputsBuffer.add(inputBuffer);
+            }
+        } else {
+            boolean isSuccess = trainSession.resize(trainSession.getInputs(), inputShapes);
+            if (!isSuccess) {
+                logger.severe(Common.addTag("session resize failed"));
+                return Status.FAILED;
+            }
+            for (int[] shapes : inputShapes) {
+                int size = IntStream.of(shapes).reduce((a, b) -> a * b).getAsInt() * Integer.BYTES;
+                ByteBuffer inputBuffer = ByteBuffer.allocateDirect(size);
+                inputBuffer.order(ByteOrder.nativeOrder());
+                inputsBuffer.add(inputBuffer);
+            }
         }
         return Status.SUCCESS;
     }
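Note that the dynamic-shape branch sizes each direct buffer as the element count times Integer.BYTES, i.e. it assumes 4-byte (float32/int32) input tensors. A worked example under that assumption:

    // For an input shape of {32, 64, 96}:
    //   element count = 32 * 64 * 96 = 196608
    //   buffer size   = 196608 * Integer.BYTES = 786432 bytes
    int[] shape = {32, 64, 96};
    int size = IntStream.of(shape).reduce((a, b) -> a * b).getAsInt() * Integer.BYTES;
    ByteBuffer inputBuffer = ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder());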
@@ -182,7 +198,7 @@ public abstract class Client {
      *
      * @return infer status.
      */
-    public List<Integer> inferModel() {
+    public List<Object> inferModel() {
         boolean isSuccess = trainSession.eval();
         if (!isSuccess) {
             logger.severe(Common.addTag("train session switch eval mode failed"));
@@ -246,17 +262,44 @@ public abstract class Client {
         return Status.SUCCESS;
     }
 
-    private Optional<LiteSession> initSession(String modelPath, MSConfig msConfig) {
+    private Optional<LiteSession> initSession(String modelPath, MSConfig msConfig, boolean isDynamicInferModel) {
         if (modelPath == null) {
             logger.severe(Common.addTag("modelPath cannot be empty"));
             return Optional.empty();
         }
-        LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
-        if (trainSession == null) {
-            logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
-            return Optional.empty();
+        LiteSession trainSession;
+        // only lite session support dynamic shape
+        if (isDynamicInferModel) {
+            Model model = new Model();
+            boolean isSuccess = model.loadModel(modelPath);
+            if (!isSuccess) {
+                logger.severe(Common.addTag("load model failed:" + modelPath));
+                return Optional.empty();
+            }
+            trainSession = LiteSession.createSession(msConfig);
+            if (trainSession == null) {
+                logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
+                msConfig.free();
+                model.free();
+                return Optional.empty();
+            }
+            msConfig.free();
+            isSuccess = trainSession.compileGraph(model);
+            if (!isSuccess) {
+                logger.severe(Common.addTag("compile graph failed:" + modelPath));
+                model.free();
+                trainSession.free();
+                return Optional.empty();
+            }
+            model.free();
+            return Optional.of(trainSession);
+        } else {
+            trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
+            if (trainSession == null) {
+                logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
+                return Optional.empty();
+            }
+            return Optional.of(trainSession);
         }
-        return Optional.of(trainSession);
     }
 
     /**
@@ -116,6 +116,12 @@ def get_parser():
     - client_num
         Specifies the number of clients. The value must be the same as that of 'start_fl_job_cnt' when the server is
         started. This parameter is not required in actual scenarios.
+    - input_shape
+        string type, this parameter is used to set the real input shape of the model, where the value corresponding to
+        the batch size in the input shape should be consistent with the parameter 'batch_size'. When the model contains
+        only one input, use a comma ',' to join the sizes of its dimensions, such as "32,64,96"; when the model
+        contains multiple inputs, use a semicolon ';' to join the shapes of the different inputs, such as
+        "32,64,96;12,24,26". Note: this parameter is only required when the inference model supports dynamic inputs.
     """
 
     parser = argparse.ArgumentParser(description="Run SyncFLJob.java case")
@@ -150,6 +156,7 @@ def get_parser():
     parser.add_argument("--name_regex", type=str, default=",")
     parser.add_argument("--server_mode", type=str, default="FEDERATED_LEARNING")
     parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--input_shape", type=str, default="null")
 
     parser.add_argument("--client_num", type=int, default=0)
     return parser
@@ -188,6 +195,7 @@ infer_weight_name = args.infer_weight_name
 name_regex = args.name_regex
 server_mode = args.server_mode
 batch_size = args.batch_size
+input_shape = args.input_shape
 
 client_num = args.client_num
 
@@ -274,6 +282,7 @@ for i in range(client_num):
     cmd_client += name_regex + " "
     cmd_client += server_mode + " "
     cmd_client += str(batch_size) + " "
+    cmd_client += input_shape + " "
     cmd_client += " > client" + ".log 2>&1 &"
     print(cmd_client)
    subprocess.call(['bash', '-c', cmd_client])