!29147 [MS][LITE] fl support dynamic input shape

Merge pull request !29147 from zhengjun10/java
This commit is contained in:
i-robot 2022-01-19 06:56:28 +00:00 committed by Gitee
commit f4cc31a7c4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 50 additions and 6 deletions

View File

@ -17,10 +17,10 @@
package com.mindspore.flclient.model;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.Model;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;
import mindspore.schema.FeatureMap;
@ -33,6 +33,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.stream.IntStream;
/**
* Defining the client base class.
@ -52,6 +53,8 @@ public abstract class Client {
*/
public Map<RunType, DataSet> dataSets = new HashMap<>();
private boolean isDynamicInferModel = false;
private final List<ByteBuffer> inputsBuffer = new ArrayList<>();
/**
@ -99,6 +102,9 @@ public abstract class Client {
logger.severe(Common.addTag("session init failed"));
return Status.NULLPTR;
}
if (isDynamicInferModel) {
logger.info(Common.addTag(modelPath + " is dynamic input"));
}
Optional<LiteSession> optTrainSession = initSession(modelPath, config);
if (!optTrainSession.isPresent()) {
logger.severe(Common.addTag("session init failed"));
@ -251,12 +257,39 @@ public abstract class Client {
logger.severe(Common.addTag("modelPath cannot be empty"));
return Optional.empty();
}
LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
return Optional.empty();
// only lite session support dynamic shape
if (isDynamicInferModel) {
Model model = new Model();
boolean isSuccess = model.loadModel(modelPath);
if (!isSuccess) {
logger.severe(Common.addTag("load model failed:" + modelPath));
return Optional.empty();
}
trainSession = LiteSession.createSession(msConfig);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
msConfig.free();
model.free();
return Optional.empty();
}
msConfig.free();
isSuccess = trainSession.compileGraph(model);
if (!isSuccess) {
logger.severe(Common.addTag("compile graph failed:" + modelPath));
model.free();
trainSession.free();
return Optional.empty();
}
model.free();
return Optional.of(trainSession);
} else {
trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
return Optional.empty();
}
return Optional.of(trainSession);
}
return Optional.of(trainSession);
}
/**
@ -346,4 +379,15 @@ public abstract class Client {
dataset.batchSize = batchSize;
}
}
/**
* Resize client input shape dims.
*
* @param dims new input shapes.
* @return resize status.
*/
public boolean resize(int[][] dims) {
isDynamicInferModel = true;
return trainSession.resize(trainSession.getInputs(), dims);
}
}