forked from mindspore-Ecosystem/mindspore
!29147 [MS][LITE] fl support dynamic input shape
Merge pull request !29147 from zhengjun10/java
This commit is contained in:
commit
f4cc31a7c4
|
@ -17,10 +17,10 @@
|
|||
package com.mindspore.flclient.model;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.LocalFLParameter;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.Model;
|
||||
import com.mindspore.lite.TrainSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
@ -33,6 +33,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Defining the client base class.
|
||||
|
@ -52,6 +53,8 @@ public abstract class Client {
|
|||
*/
|
||||
public Map<RunType, DataSet> dataSets = new HashMap<>();
|
||||
|
||||
private boolean isDynamicInferModel = false;
|
||||
|
||||
private final List<ByteBuffer> inputsBuffer = new ArrayList<>();
|
||||
|
||||
/**
|
||||
|
@ -99,6 +102,9 @@ public abstract class Client {
|
|||
logger.severe(Common.addTag("session init failed"));
|
||||
return Status.NULLPTR;
|
||||
}
|
||||
if (isDynamicInferModel) {
|
||||
logger.info(Common.addTag(modelPath + " is dynamic input"));
|
||||
}
|
||||
Optional<LiteSession> optTrainSession = initSession(modelPath, config);
|
||||
if (!optTrainSession.isPresent()) {
|
||||
logger.severe(Common.addTag("session init failed"));
|
||||
|
@ -251,12 +257,39 @@ public abstract class Client {
|
|||
logger.severe(Common.addTag("modelPath cannot be empty"));
|
||||
return Optional.empty();
|
||||
}
|
||||
LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
|
||||
if (trainSession == null) {
|
||||
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
|
||||
return Optional.empty();
|
||||
// only lite session support dynamic shape
|
||||
if (isDynamicInferModel) {
|
||||
Model model = new Model();
|
||||
boolean isSuccess = model.loadModel(modelPath);
|
||||
if (!isSuccess) {
|
||||
logger.severe(Common.addTag("load model failed:" + modelPath));
|
||||
return Optional.empty();
|
||||
}
|
||||
trainSession = LiteSession.createSession(msConfig);
|
||||
if (trainSession == null) {
|
||||
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
|
||||
msConfig.free();
|
||||
model.free();
|
||||
return Optional.empty();
|
||||
}
|
||||
msConfig.free();
|
||||
isSuccess = trainSession.compileGraph(model);
|
||||
if (!isSuccess) {
|
||||
logger.severe(Common.addTag("compile graph failed:" + modelPath));
|
||||
model.free();
|
||||
trainSession.free();
|
||||
return Optional.empty();
|
||||
}
|
||||
model.free();
|
||||
return Optional.of(trainSession);
|
||||
} else {
|
||||
trainSession = TrainSession.createTrainSession(modelPath, msConfig, false);
|
||||
if (trainSession == null) {
|
||||
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.of(trainSession);
|
||||
}
|
||||
return Optional.of(trainSession);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -346,4 +379,15 @@ public abstract class Client {
|
|||
dataset.batchSize = batchSize;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resize client input shape dims.
|
||||
*
|
||||
* @param dims new input shapes.
|
||||
* @return resize status.
|
||||
*/
|
||||
public boolean resize(int[][] dims) {
|
||||
isDynamicInferModel = true;
|
||||
return trainSession.resize(trainSession.getInputs(), dims);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue