Sync examples to the new Java API

This commit is contained in:
zhengjun10 2022-01-10 16:28:47 +08:00
parent d5502bab19
commit 940e31860d
7 changed files with 362 additions and 207 deletions

View File

@ -1,12 +1,12 @@
package com.mindspore.lite.demo;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.Model;
import com.mindspore.lite.DataType;
import com.mindspore.lite.Version;
import com.mindspore.lite.config.MSConfig;
import com.mindspore.lite.config.DeviceType;
import com.mindspore.MSTensor;
import com.mindspore.Model;
import com.mindspore.config.DataType;
import com.mindspore.config.DeviceType;
import com.mindspore.config.MSContext;
import com.mindspore.config.ModelType;
import com.mindspore.config.Version;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@ -15,7 +15,6 @@ import java.util.Random;
public class Main {
private static Model model;
private static LiteSession session;
public static float[] generateArray(int len) {
Random rand = new Random();
@ -37,28 +36,20 @@ public class Main {
return buffer;
}
private static boolean compile() {
MSConfig msConfig = new MSConfig();
// You can set config through Init Api or use the default parameters directly.
// The default parameter is that the backend type is DeviceType.DT_CPU, and the number of threads is 2.
boolean ret = msConfig.init(DeviceType.DT_CPU, 2);
private static boolean compile(String modelPath) {
MSContext context = new MSContext();
// use default param init context
context.init();
boolean ret = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
if (!ret) {
System.err.println("Init context failed");
System.err.println("Compile graph failed");
context.free();
return false;
}
// Create the MindSpore lite session.
session = new LiteSession();
ret = session.init(msConfig);
msConfig.free();
if (!ret) {
System.err.println("Create session failed");
model.free();
return false;
}
model = new Model();
// Compile graph.
ret = session.compileGraph(model);
ret = model.build(modelPath, ModelType.MT_MINDIR, context);
if (!ret) {
System.err.println("Compile graph failed");
model.free();
@ -68,7 +59,7 @@ public class Main {
}
private static boolean run() {
MSTensor inputTensor = session.getInputsByTensorName("graph_input-173");
MSTensor inputTensor = model.getInputByTensorName("graph_input-173");
if (inputTensor.getDataType() != DataType.kNumberTypeFloat32) {
System.err.println("Input tensor shape do not float, the data type is " + inputTensor.getDataType());
return false;
@ -82,14 +73,14 @@ public class Main {
inputTensor.setData(inputData);
// Run Inference.
boolean ret = session.runGraph();
boolean ret = model.predict();
if (!ret) {
System.err.println("MindSpore Lite run failed.");
return false;
}
// Get Output Tensor Data.
MSTensor outTensor = session.getOutputByTensorName("Softmax-65");
MSTensor outTensor = model.getOutputByTensorName("Softmax-65");
// Print out Tensor Data.
StringBuilder msgSb = new StringBuilder();
@ -117,7 +108,6 @@ public class Main {
}
private static void freeBuffer() {
session.free();
model.free();
}
@ -128,14 +118,7 @@ public class Main {
return;
}
String modelPath = args[0];
model = new Model();
boolean ret = model.loadModel(modelPath);
if (!ret) {
System.err.println("Load model failed, model path is " + modelPath);
return;
}
ret = compile();
boolean ret = compile(modelPath);
if (!ret) {
System.err.println("MindSpore Lite compile failed.");
return;

View File

@ -8,33 +8,29 @@ import android.view.View;
import android.widget.TextView;
import android.widget.Toast;
import com.mindspore.lite.DataType;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.Model;
import com.mindspore.lite.config.CpuBindMode;
import com.mindspore.lite.config.DeviceType;
import com.mindspore.lite.config.MSConfig;
import com.mindspore.lite.Version;
import com.mindspore.MSTensor;
import com.mindspore.Model;
import com.mindspore.config.CpuBindMode;
import com.mindspore.config.DataType;
import com.mindspore.config.MSContext;
import com.mindspore.config.ModelType;
import com.mindspore.config.Version;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
public class MainActivity extends AppCompatActivity {
private String TAG = "MS_LITE";
private Model model;
private LiteSession session1;
private LiteSession session2;
private boolean session1Finish = true;
private boolean session2Finish = true;
private boolean session1Compile = false;
private boolean session2Compile = false;
private Model model1;
private Model model2;
private boolean model1Finish = true;
private boolean model2Finish = true;
private boolean model1Compile = false;
private boolean model2Compile = false;
public float[] generateArray(int len) {
Random rand = new Random();
@ -56,61 +52,58 @@ public class MainActivity extends AppCompatActivity {
return buffer.array();
}
private MSConfig createCPUConfig() {
MSConfig msConfig = new MSConfig();
boolean ret = msConfig.init(DeviceType.DT_CPU, 2, CpuBindMode.HIGHER_CPU, true);
private MSContext createCPUConfig() {
MSContext context = new MSContext();
context.init(2, CpuBindMode.HIGHER_CPU, false);
boolean ret = context.addDeviceInfo(com.mindspore.config.DeviceType.DT_CPU, false, 0);
if (!ret) {
Log.e(TAG, "Create CPU Config failed.");
return null;
}
return msConfig;
return context;
}
private MSConfig createGPUConfig() {
MSConfig msConfig = new MSConfig();
boolean ret = msConfig.init(DeviceType.DT_GPU, 2, CpuBindMode.MID_CPU, true);
private MSContext createGPUConfig() {
MSContext context = new MSContext();
context.init(2, CpuBindMode.MID_CPU, false);
boolean ret = context.addDeviceInfo(com.mindspore.config.DeviceType.DT_GPU, true, 0);
if (!ret) {
Log.e(TAG, "Create GPU Config failed.");
return null;
}
return msConfig;
return context;
}
private LiteSession createLiteSession(boolean isResize) {
MSConfig msConfig = createCPUConfig();
if (msConfig == null) {
private Model createLiteModel(String filePath, boolean isResize) {
MSContext msContext = createCPUConfig();
if (msContext == null) {
Log.e(TAG, "Init context failed");
return null;
}
// Create the MindSpore lite session.
LiteSession session = new LiteSession();
boolean ret = session.init(msConfig);
msConfig.free();
if (!ret) {
Log.e(TAG, "Create session failed");
return null;
}
// Create the MindSpore lite model.
Model model = new Model();
// Compile graph.
ret = session.compileGraph(model);
boolean ret = model.build(filePath, ModelType.MT_MINDIR, msContext);
if (!ret) {
session.free();
model.free();
Log.e(TAG, "Compile graph failed");
return null;
}
if (isResize) {
List<MSTensor> inputs = session.getInputs();
List<MSTensor> inputs = model.getInputs();
int[][] dims = {{1, 300, 300, 3}};
ret = session.resize(inputs, dims);
ret = model.resize(inputs, dims);
if (!ret) {
Log.e(TAG, "Resize failed");
session.free();
model.free();
return null;
}
StringBuilder msgSb = new StringBuilder();
msgSb.append("in tensor shape: [");
int[] shape = session.getInputs().get(0).getShape();
int[] shape = model.getInputs().get(0).getShape();
for (int dim : shape) {
msgSb.append(dim).append(",");
}
@ -118,7 +111,7 @@ public class MainActivity extends AppCompatActivity {
Log.i(TAG, msgSb.toString());
}
return session;
return model;
}
private boolean printTensorData(MSTensor outTensor) {
@ -146,9 +139,9 @@ public class MainActivity extends AppCompatActivity {
return true;
}
private boolean runInference(LiteSession session) {
private boolean runInference(Model model) {
Log.i(TAG, "runInference: ");
MSTensor inputTensor = session.getInputsByTensorName("graph_input-173");
MSTensor inputTensor = model.getInputByTensorName("graph_input-173");
if (inputTensor.getDataType() != DataType.kNumberTypeFloat32) {
Log.e(TAG, "Input tensor shape do not float, the data type is " + inputTensor.getDataType());
return false;
@ -160,39 +153,30 @@ public class MainActivity extends AppCompatActivity {
// Set Input Data.
inputTensor.setData(inputData);
session.bindThread(true);
// Run Inference.
boolean ret = session.runGraph();
session.bindThread(false);
boolean ret = model.predict();
if (!ret) {
Log.e(TAG, "MindSpore Lite run failed.");
return false;
}
// Get Output Tensor Data.
MSTensor outTensor = session.getOutputByTensorName("Softmax-65");
MSTensor outTensor = model.getOutputByTensorName("Softmax-65");
// Print out Tensor Data.
ret = printTensorData(outTensor);
if (!ret) {
return false;
}
outTensor = session.getOutputsByNodeName("Softmax-65").get(0);
outTensor = model.getOutputsByNodeName("Softmax-65").get(0);
ret = printTensorData(outTensor);
if (!ret) {
return false;
}
Map<String, MSTensor> outTensors = session.getOutputMapByTensor();
Iterator<Map.Entry<String, MSTensor>> entries = outTensors.entrySet().iterator();
while (entries.hasNext()) {
Map.Entry<String, MSTensor> entry = entries.next();
Log.i(TAG, "Tensor name is:" + entry.getKey());
ret = printTensorData(entry.getValue());
List<MSTensor> outTensors = model.getOutputs();
for (MSTensor output : outTensors) {
Log.i(TAG, "Tensor name is:" + output.tensorName());
ret = printTensorData(output);
if (!ret) {
return false;
}
@ -202,9 +186,8 @@ public class MainActivity extends AppCompatActivity {
}
private void freeBuffer() {
session1.free();
session2.free();
model.free();
model1.free();
model2.free();
}
@ -214,45 +197,33 @@ public class MainActivity extends AppCompatActivity {
setContentView(R.layout.activity_main);
String version = Version.version();
Log.i(TAG, version);
model = new Model();
String modelPath = "mobilenetv2.ms";
boolean ret = model.loadModel(this.getApplicationContext(), modelPath);
if (!ret) {
Log.e(TAG, "Load model failed, model is " + modelPath);
model1 = createLiteModel(modelPath, false);
if (model1 != null) {
model1Compile = true;
} else {
session1 = createLiteSession(false);
if (session1 != null) {
session1Compile = true;
} else {
Toast.makeText(getApplicationContext(), "session1 Compile Failed.",
Toast.LENGTH_SHORT).show();
}
session2 = createLiteSession(true);
if (session2 != null) {
session2Compile = true;
} else {
Toast.makeText(getApplicationContext(), "session2 Compile Failed.",
Toast.LENGTH_SHORT).show();
}
Toast.makeText(getApplicationContext(), "model1 Compile Failed.",
Toast.LENGTH_SHORT).show();
}
if (model != null) {
// Note: when use model.freeBuffer(), the model can not be compiled again.
model.freeBuffer();
model2 = createLiteModel(modelPath, true);
if (model2 != null) {
model2Compile = true;
} else {
Toast.makeText(getApplicationContext(), "model2 Compile Failed.",
Toast.LENGTH_SHORT).show();
}
TextView btn_run = findViewById(R.id.btn_run);
btn_run.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
if (session1Finish && session1Compile) {
if (model1Finish && model1Compile) {
new Thread(new Runnable() {
@Override
public void run() {
session1Finish = false;
runInference(session1);
session1Finish = true;
model1Finish = false;
runInference(model1);
model1Finish = true;
}
}).start();
} else {
@ -266,27 +237,27 @@ public class MainActivity extends AppCompatActivity {
new View.OnClickListener() {
@Override
public void onClick(View v) {
if (session1Finish && session1Compile) {
if (model1Finish && model1Compile) {
new Thread(new Runnable() {
@Override
public void run() {
session1Finish = false;
runInference(session1);
session1Finish = true;
model1Finish = false;
runInference(model2);
model1Finish = true;
}
}).start();
}
if (session2Finish && session2Compile) {
if (model2Finish && model2Compile) {
new Thread(new Runnable() {
@Override
public void run() {
session2Finish = false;
runInference(session2);
session2Finish = true;
model2Finish = false;
runInference(model2);
model2Finish = true;
}
}).start();
}
if (!session2Finish && !session2Finish) {
if (!model2Finish && !model2Finish) {
Toast.makeText(getApplicationContext(), "MindSpore Lite is running...",
Toast.LENGTH_SHORT).show();
}

View File

@ -16,29 +16,28 @@
package com.mindspore.lite.train_lenet;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;
import com.mindspore.Graph;
import com.mindspore.Model;
import com.mindspore.config.DeviceType;
import com.mindspore.config.MSContext;
import com.mindspore.config.TrainCfg;
import com.mindspore.MSTensor;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Vector;
public class NetRunner {
private int dataIndex = 0;
private int labelIndex = 1;
private LiteSession session;
private Model model;
private int batchSize;
private long dataSize; // one input data size, in byte
private DataSet ds = new DataSet();
private final DataSet ds = new DataSet();
private long numOfClasses;
private long cycles = 2000;
private final long cycles = 2000;
private int idx = 1;
private int virtualBatch = 16;
private String trainedFilePath = "trained.ms";
private ByteBuffer imageInputBuf;
private ByteBuffer labelInputBuf;
private int imageBatchElements;
@ -46,37 +45,19 @@ public class NetRunner {
private MSTensor labelTensor;
private int[] targetLabels;
public void initAndFigureInputs(String modelPath, int virtualBatchSize) {
MSConfig msConfig = new MSConfig();
// arg 0: DeviceType:DT_CPU -> 0
// arg 1: ThreadNum -> 2
// arg 2: cpuBindMode:NO_BIND -> 0
// arg 3: enable_fp16 -> false
msConfig.init(0, 2, 0, false);
session = new LiteSession();
System.out.println("Model path is " + modelPath);
session = TrainSession.createTrainSession(modelPath, msConfig, false);
virtualBatch = virtualBatchSize;
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
List<MSTensor> inputs = session.getInputs();
private int initInputs() {
List<MSTensor> inputs = model.getInputs();
if (inputs.size() <= 1) {
System.err.println("model input size: " + inputs.size());
return;
return -1;
}
dataIndex = 0;
labelIndex = 1;
int dataIndex = 0;
int labelIndex = 1;
batchSize = inputs.get(dataIndex).getShape()[0];
dataSize = inputs.get(dataIndex).size() / batchSize;
System.out.println("batch_size: " + batchSize);
System.out.println("virtual batch multiplier: " + virtualBatch);
int index = modelPath.lastIndexOf(".ms");
if (index == -1) {
System.out.println("The model " + modelPath + " should be named *.ms");
return;
}
trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
imageTensor = inputs.get(dataIndex);
imageInputBuf = ByteBuffer.allocateDirect((int) imageTensor.size());
@ -87,6 +68,46 @@ public class NetRunner {
labelInputBuf = ByteBuffer.allocateDirect((int) labelTensor.size());
labelInputBuf.order(ByteOrder.nativeOrder());
targetLabels = new int[batchSize];
return 0;
}
public int initAndFigureInputs(String modelPath, int virtualBatchSize) {
System.out.println("Model path is " + modelPath);
MSContext context = new MSContext();
// use default param init context
context.init();
boolean isSuccess = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
if (!isSuccess) {
System.err.println("Load graph failed");
context.free();
return -1;
}
TrainCfg trainCfg = new TrainCfg();
isSuccess = trainCfg.init();
if (!isSuccess) {
System.err.println("Init train config failed");
context.free();
trainCfg.free();
return -1;
}
model = new Model();
Graph graph = new Graph();
isSuccess = graph.load(modelPath);
if (!isSuccess) {
System.err.println("Load graph failed");
graph.free();
context.free();
trainCfg.free();
return -1;
}
isSuccess = model.build(graph, context, trainCfg);
if (!isSuccess) {
System.err.println("Build model failed");
return -1;
}
virtualBatch = virtualBatchSize;
model.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
return initInputs();
}
public int initDB(String datasetPath) {
@ -110,12 +131,16 @@ public class NetRunner {
public float getLoss() {
MSTensor tensor = searchOutputsForSize(1);
if (tensor == null) {
System.err.println("get loss tensor failed");
return Float.NaN;
}
return tensor.getFloatData()[0];
}
private MSTensor searchOutputsForSize(int size) {
Map<String, MSTensor> outputs = session.getOutputMapByTensor();
for (MSTensor tensor : outputs.values()) {
List<MSTensor> outputs = model.getOutputs();
for (MSTensor tensor : outputs) {
if (tensor.elementsNum() == size) {
return tensor;
}
@ -125,13 +150,23 @@ public class NetRunner {
}
public int trainLoop() {
session.train();
boolean isSuccess = model.setTrainMode(true);
if (!isSuccess) {
model.free();
System.err.println("set train mode failed");
return -1;
}
float min_loss = 1000;
float max_acc = 0;
for (int i = 0; i < cycles; i++) {
for (int b = 0; b < virtualBatch; b++) {
fillInputData(ds.getTrainData(), false);
session.runGraph();
isSuccess = model.runStep();
if (!isSuccess) {
model.free();
System.err.println("run step failed");
return -1;
}
float loss = getLoss();
if (min_loss > loss) {
min_loss = loss;
@ -156,13 +191,14 @@ public class NetRunner {
if (maxTests != -1 && tests < maxTests) {
tests = maxTests;
}
session.eval();
model.setTrainMode(false);
for (long i = 0; i < tests; i++) {
int[] labels = fillInputData(test_set, (maxTests == -1));
session.runGraph();
model.predict();
MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
if (outputsv == null) {
System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
model.free();
System.exit(1);
}
float[] scores = outputsv.getFloatData();
@ -181,7 +217,7 @@ public class NetRunner {
}
}
}
session.train();
model.setTrainMode(true);
accuracy /= (batchSize * tests);
return accuracy;
}
@ -212,27 +248,48 @@ public class NetRunner {
}
public void trainModel(String modelPath, String datasetPath, int virtualBatch) {
int index = modelPath.lastIndexOf(".ms");
if (index == -1) {
System.err.println("The model " + modelPath + " should be named *.ms");
return;
}
System.out.println("==========Loading Model, Create Train Session=============");
initAndFigureInputs(modelPath, virtualBatch);
int ret = initAndFigureInputs(modelPath, virtualBatch);
if (ret != 0) {
System.out.println("==========Init and figure inputs failed================");
model.free();
return;
}
System.out.println("==========Initing DataSet================");
initDB(datasetPath);
ret = initDB(datasetPath);
if (ret != 0) {
System.out.println("==========Init dataset failed================");
return;
}
System.out.println("==========Training Model===================");
trainLoop();
ret = trainLoop();
if (ret != 0) {
System.out.println("==========Init dataset failed================");
model.free();
return;
}
System.out.println("==========Evaluating The Trained Model============");
float acc = calculateAccuracy(-1);
System.out.println("accuracy = " + acc);
if (cycles > 0) {
// arg 0: FileName
// arg 1: model type MT_TRAIN -> 0
// arg 2: quantization type QT_DEFAULT -> 0
if (session.export(trainedFilePath, 0, 0)) {
// arg 1: quantization type QT_DEFAULT -> 0
// arg 2: model type MT_TRAIN -> 0
// arg 3: use default output tensor names
String trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
if (model.export(trainedFilePath, 0, false, new ArrayList<>())) {
System.out.println("Trained model successfully saved: " + trainedFilePath);
} else {
System.err.println("Save model error.");
}
}
session.free();
model.free();
}
}

View File

@ -219,6 +219,34 @@ public class Model {
return new MSTensor(tensorAddr);
}
/**
 * Get the output tensors produced by the node with the given name.
 *
 * @param nodeName output node name; when null, null is returned.
 * @return list of output tensors for that node.
 */
public List<MSTensor> getOutputsByNodeName(String nodeName) {
    if (nodeName == null) {
        return null;
    }
    List<MSTensor> tensors = new ArrayList<>();
    // Each native handle is wrapped in a Java-side MSTensor view.
    for (Long tensorAddr : this.getOutputsByNodeName(this.modelPtr, nodeName)) {
        tensors.add(new MSTensor(tensorAddr));
    }
    return tensors;
}
/**
 * Get the names of all output tensors of this model.
 *
 * @return output tensor name list (order as reported by the native layer).
 */
public List<String> getOutputTensorNames() {
return this.getOutputTensorNames(this.modelPtr);
}
/**
* Export the model.
*
@ -230,14 +258,17 @@ public class Model {
*/
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer,
List<String> outputTensorNames) {
if (fileName == null || outputTensorNames == null) {
if (fileName == null) {
return false;
}
String[] outputTensorArray = new String[outputTensorNames.size()];
for (int i = 0; i < outputTensorNames.size(); i++) {
outputTensorArray[i] = outputTensorNames.get(i);
if (outputTensorNames != null) {
String[] outputTensorArray = new String[outputTensorNames.size()];
for (int i = 0; i < outputTensorNames.size(); i++) {
outputTensorArray[i] = outputTensorNames.get(i);
}
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorArray);
}
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorArray);
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, null);
}
/**
@ -291,6 +322,28 @@ public class Model {
return this.getTrainMode(modelPtr);
}
/**
 * Set the learning rate used by the training session.
 *
 * @param learning_rate learning rate to apply.
 * @return Whether the learning rate was set successfully.
 */
public boolean setLearningRate(float learning_rate) {
return this.setLearningRate(this.modelPtr, learning_rate);
}
/**
 * Set the virtual batch.
 *
 * @param virtualBatchMultiplier virtual batch multiplier.
 * @param learningRate learning rate.
 * @param momentum momentum.
 * @return Whether the virtual batch is successfully set.
 */
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) {
return this.setupVirtualBatch(this.modelPtr, virtualBatchMultiplier, learningRate, momentum);
}
/**
* Free model
*/
@ -317,6 +370,10 @@ public class Model {
private native long getOutputByTensorName(long modelPtr, String tensorName);
private native List<String> getOutputTensorNames(long modelPtr);
private native List<Long> getOutputsByNodeName(long modelPtr, String nodeName);
private native boolean setTrainMode(long modelPtr, boolean isTrain);
private native boolean getTrainMode(long modelPtr);
@ -329,4 +386,9 @@ public class Model {
private native List<Long> getFeatureMaps(long modelPtr);
private native boolean updateFeatureMaps(long modelPtr, long[] newFeatures);
private native boolean setLearningRate(long modelPtr, float learning_rate);
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate,
float momentum);
}

View File

@ -258,12 +258,60 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_getOutputByTensorNam
return GetTensorByInOutName(env, model_ptr, tensor_name, false);
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputTensorNames(JNIEnv *env, jobject thiz,
jlong model_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto output_names = lite_model_ptr->GetOutputTensorNames();
for (const auto &output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
}
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeName(JNIEnv *env, jobject thiz,
jlong model_ptr, jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto tensors = lite_model_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
for (auto &tensor : tensors) {
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
if (tensor_ptr == nullptr) {
MS_LOGE("Make ms tensor failed");
return ret;
}
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_getTrainMode(JNIEnv *env, jobject thiz,
jlong model_ptr) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(false);
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
return static_cast<jboolean>(lite_model_ptr->GetTrainMode());
@ -285,7 +333,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_runStep(JNIEnv *e
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(false);
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto status = lite_model_ptr->RunStep(nullptr, nullptr);
@ -313,7 +361,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_predict(JNIEnv *e
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return jlong(false);
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto c_inputs = convertArrayToVector(env, inputs);
@ -328,7 +376,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return false;
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
@ -339,7 +387,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return false;
return (jboolean) false;
}
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
c_inputs.push_back(ms_tensor);
@ -403,7 +451,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_updateFeatureMaps
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
if (tensor_pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return false;
return (jboolean) false;
}
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer);
newFeatures.emplace_back(*ms_tensor_ptr);
@ -441,6 +489,30 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getFeatureMaps(JNI
return ret;
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_setLearningRate(JNIEnv *env, jclass, jlong model_ptr,
jfloat learning_rate) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto ret = lite_model_ptr->SetLearningRate(learning_rate);
return (jboolean)(ret.IsOk());
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_setupVirtualBatch(
JNIEnv *env, jobject thiz, jlong model_ptr, jint virtual_batch_factor, jfloat learning_rate, jfloat momentum) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
auto ret = lite_model_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
return (jboolean)(ret.IsOk());
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {

View File

@ -34,14 +34,12 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_MSContext_createMSC
return (jlong)context;
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceInfo(JNIEnv *env, jobject thiz,
jlong context_ptr, jint device_type,
jboolean enable_fp16,
jint npu_freq) {
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_addDeviceInfo(
JNIEnv *env, jobject thiz, jlong context_ptr, jint device_type, jboolean enable_fp16, jint npu_freq) {
auto *pointer = reinterpret_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return;
return (jboolean) false;
}
auto *c_context_ptr = static_cast<mindspore::Context *>(pointer);
auto &device_list = c_context_ptr->MutableDeviceInfo();
@ -51,7 +49,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
if (cpu_device_info == nullptr) {
MS_LOGE("cpu device info is nullptr");
delete (c_context_ptr);
return;
return (jboolean) false;
}
cpu_device_info->SetEnableFP16(enable_fp16);
device_list.push_back(cpu_device_info);
@ -63,7 +61,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
if (gpu_device_info == nullptr) {
MS_LOGE("gpu device info is nullptr");
delete (c_context_ptr);
return;
return (jboolean) false;
}
gpu_device_info->SetEnableFP16(enable_fp16);
device_list.push_back(gpu_device_info);
@ -75,7 +73,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
if (npu_device_info == nullptr) {
MS_LOGE("npu device info is nullptr");
delete (c_context_ptr);
return;
return (jboolean) false;
}
npu_device_info->SetFrequency(npu_freq);
device_list.push_back(npu_device_info);
@ -84,7 +82,9 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
default:
MS_LOGE("Invalid device_type : %d", device_type);
delete (c_context_ptr);
return (jboolean) false;
}
return (jboolean) true;
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz,

View File

@ -38,6 +38,10 @@ public class ModelTest {
Model liteModel = new Model();
boolean isSuccess = liteModel.build(g, context, cfg);
assertTrue(isSuccess);
isSuccess = liteModel.setLearningRate(1.0f);
assertTrue(isSuccess);
isSuccess = liteModel.setupVirtualBatch(2,1.0f,0.5f);
assertTrue(isSuccess);
liteModel.free();
}
@ -145,9 +149,8 @@ public class ModelTest {
for (MSTensor output : outputs) {
System.out.println("output-------" + output.tensorName());
}
System.out.println("");
MSTensor output = liteModel.getOutputByTensorName("Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense" +
"/BiasAdd-op121");
String outputTensorName = "Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op121";
MSTensor output = liteModel.getOutputByTensorName(outputTensorName);
assertEquals(80, output.size());
output = liteModel.getOutputByTensorName("Default/network-WithLossCell/_loss_fn-L1Loss/ReduceMean-op112");
assertEquals(0, output.size());
@ -155,6 +158,13 @@ public class ModelTest {
for (MSTensor input : inputs) {
System.out.println(input.tensorName());
}
for (String name : liteModel.getOutputTensorNames()) {
System.out.println("output tensor name:" + name);
}
List<MSTensor> outputTensors = liteModel.getOutputsByNodeName("Default/network-WithLossCell/_backbone-LeNet5" +
"/fc3-Dense/MatMul-op118");
assertEquals(1, outputTensors.size());
assertEquals(outputTensorName, outputTensors.get(0).tensorName());
liteModel.free();
}