sync examples to new java api
This commit is contained in:
parent
d5502bab19
commit
940e31860d
|
@ -1,12 +1,12 @@
|
|||
package com.mindspore.lite.demo;
|
||||
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.Model;
|
||||
import com.mindspore.lite.DataType;
|
||||
import com.mindspore.lite.Version;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
import com.mindspore.lite.config.DeviceType;
|
||||
import com.mindspore.MSTensor;
|
||||
import com.mindspore.Model;
|
||||
import com.mindspore.config.DataType;
|
||||
import com.mindspore.config.DeviceType;
|
||||
import com.mindspore.config.MSContext;
|
||||
import com.mindspore.config.ModelType;
|
||||
import com.mindspore.config.Version;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
@ -15,7 +15,6 @@ import java.util.Random;
|
|||
|
||||
public class Main {
|
||||
private static Model model;
|
||||
private static LiteSession session;
|
||||
|
||||
public static float[] generateArray(int len) {
|
||||
Random rand = new Random();
|
||||
|
@ -37,28 +36,20 @@ public class Main {
|
|||
return buffer;
|
||||
}
|
||||
|
||||
private static boolean compile() {
|
||||
MSConfig msConfig = new MSConfig();
|
||||
// You can set config through Init Api or use the default parameters directly.
|
||||
// The default parameter is that the backend type is DeviceType.DT_CPU, and the number of threads is 2.
|
||||
boolean ret = msConfig.init(DeviceType.DT_CPU, 2);
|
||||
private static boolean compile(String modelPath) {
|
||||
MSContext context = new MSContext();
|
||||
// use default param init context
|
||||
context.init();
|
||||
boolean ret = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
|
||||
if (!ret) {
|
||||
System.err.println("Init context failed");
|
||||
System.err.println("Compile graph failed");
|
||||
context.free();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create the MindSpore lite session.
|
||||
session = new LiteSession();
|
||||
ret = session.init(msConfig);
|
||||
msConfig.free();
|
||||
if (!ret) {
|
||||
System.err.println("Create session failed");
|
||||
model.free();
|
||||
return false;
|
||||
}
|
||||
|
||||
model = new Model();
|
||||
// Compile graph.
|
||||
ret = session.compileGraph(model);
|
||||
ret = model.build(modelPath, ModelType.MT_MINDIR, context);
|
||||
if (!ret) {
|
||||
System.err.println("Compile graph failed");
|
||||
model.free();
|
||||
|
@ -68,7 +59,7 @@ public class Main {
|
|||
}
|
||||
|
||||
private static boolean run() {
|
||||
MSTensor inputTensor = session.getInputsByTensorName("graph_input-173");
|
||||
MSTensor inputTensor = model.getInputByTensorName("graph_input-173");
|
||||
if (inputTensor.getDataType() != DataType.kNumberTypeFloat32) {
|
||||
System.err.println("Input tensor shape do not float, the data type is " + inputTensor.getDataType());
|
||||
return false;
|
||||
|
@ -82,14 +73,14 @@ public class Main {
|
|||
inputTensor.setData(inputData);
|
||||
|
||||
// Run Inference.
|
||||
boolean ret = session.runGraph();
|
||||
boolean ret = model.predict();
|
||||
if (!ret) {
|
||||
System.err.println("MindSpore Lite run failed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get Output Tensor Data.
|
||||
MSTensor outTensor = session.getOutputByTensorName("Softmax-65");
|
||||
MSTensor outTensor = model.getOutputByTensorName("Softmax-65");
|
||||
|
||||
// Print out Tensor Data.
|
||||
StringBuilder msgSb = new StringBuilder();
|
||||
|
@ -117,7 +108,6 @@ public class Main {
|
|||
}
|
||||
|
||||
private static void freeBuffer() {
|
||||
session.free();
|
||||
model.free();
|
||||
}
|
||||
|
||||
|
@ -128,14 +118,7 @@ public class Main {
|
|||
return;
|
||||
}
|
||||
String modelPath = args[0];
|
||||
model = new Model();
|
||||
|
||||
boolean ret = model.loadModel(modelPath);
|
||||
if (!ret) {
|
||||
System.err.println("Load model failed, model path is " + modelPath);
|
||||
return;
|
||||
}
|
||||
ret = compile();
|
||||
boolean ret = compile(modelPath);
|
||||
if (!ret) {
|
||||
System.err.println("MindSpore Lite compile failed.");
|
||||
return;
|
||||
|
|
|
@ -8,33 +8,29 @@ import android.view.View;
|
|||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import com.mindspore.lite.DataType;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.Model;
|
||||
import com.mindspore.lite.config.CpuBindMode;
|
||||
import com.mindspore.lite.config.DeviceType;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
import com.mindspore.lite.Version;
|
||||
import com.mindspore.MSTensor;
|
||||
import com.mindspore.Model;
|
||||
import com.mindspore.config.CpuBindMode;
|
||||
import com.mindspore.config.DataType;
|
||||
import com.mindspore.config.MSContext;
|
||||
import com.mindspore.config.ModelType;
|
||||
import com.mindspore.config.Version;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
|
||||
public class MainActivity extends AppCompatActivity {
|
||||
private String TAG = "MS_LITE";
|
||||
private Model model;
|
||||
private LiteSession session1;
|
||||
private LiteSession session2;
|
||||
private boolean session1Finish = true;
|
||||
private boolean session2Finish = true;
|
||||
private boolean session1Compile = false;
|
||||
private boolean session2Compile = false;
|
||||
private Model model1;
|
||||
private Model model2;
|
||||
private boolean model1Finish = true;
|
||||
private boolean model2Finish = true;
|
||||
private boolean model1Compile = false;
|
||||
private boolean model2Compile = false;
|
||||
|
||||
public float[] generateArray(int len) {
|
||||
Random rand = new Random();
|
||||
|
@ -56,61 +52,58 @@ public class MainActivity extends AppCompatActivity {
|
|||
return buffer.array();
|
||||
}
|
||||
|
||||
private MSConfig createCPUConfig() {
|
||||
MSConfig msConfig = new MSConfig();
|
||||
boolean ret = msConfig.init(DeviceType.DT_CPU, 2, CpuBindMode.HIGHER_CPU, true);
|
||||
private MSContext createCPUConfig() {
|
||||
MSContext context = new MSContext();
|
||||
context.init(2, CpuBindMode.HIGHER_CPU, false);
|
||||
boolean ret = context.addDeviceInfo(com.mindspore.config.DeviceType.DT_CPU, false, 0);
|
||||
if (!ret) {
|
||||
Log.e(TAG, "Create CPU Config failed.");
|
||||
return null;
|
||||
}
|
||||
return msConfig;
|
||||
return context;
|
||||
}
|
||||
|
||||
private MSConfig createGPUConfig() {
|
||||
MSConfig msConfig = new MSConfig();
|
||||
boolean ret = msConfig.init(DeviceType.DT_GPU, 2, CpuBindMode.MID_CPU, true);
|
||||
private MSContext createGPUConfig() {
|
||||
MSContext context = new MSContext();
|
||||
context.init(2, CpuBindMode.MID_CPU, false);
|
||||
boolean ret = context.addDeviceInfo(com.mindspore.config.DeviceType.DT_GPU, true, 0);
|
||||
if (!ret) {
|
||||
Log.e(TAG, "Create GPU Config failed.");
|
||||
return null;
|
||||
}
|
||||
return msConfig;
|
||||
return context;
|
||||
}
|
||||
|
||||
private LiteSession createLiteSession(boolean isResize) {
|
||||
MSConfig msConfig = createCPUConfig();
|
||||
if (msConfig == null) {
|
||||
private Model createLiteModel(String filePath, boolean isResize) {
|
||||
MSContext msContext = createCPUConfig();
|
||||
if (msContext == null) {
|
||||
Log.e(TAG, "Init context failed");
|
||||
return null;
|
||||
}
|
||||
|
||||
// Create the MindSpore lite session.
|
||||
LiteSession session = new LiteSession();
|
||||
boolean ret = session.init(msConfig);
|
||||
msConfig.free();
|
||||
if (!ret) {
|
||||
Log.e(TAG, "Create session failed");
|
||||
return null;
|
||||
}
|
||||
// Create the MindSpore lite model.
|
||||
Model model = new Model();
|
||||
|
||||
// Compile graph.
|
||||
ret = session.compileGraph(model);
|
||||
boolean ret = model.build(filePath, ModelType.MT_MINDIR, msContext);
|
||||
if (!ret) {
|
||||
session.free();
|
||||
model.free();
|
||||
Log.e(TAG, "Compile graph failed");
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isResize) {
|
||||
List<MSTensor> inputs = session.getInputs();
|
||||
List<MSTensor> inputs = model.getInputs();
|
||||
int[][] dims = {{1, 300, 300, 3}};
|
||||
ret = session.resize(inputs, dims);
|
||||
ret = model.resize(inputs, dims);
|
||||
if (!ret) {
|
||||
Log.e(TAG, "Resize failed");
|
||||
session.free();
|
||||
model.free();
|
||||
return null;
|
||||
}
|
||||
StringBuilder msgSb = new StringBuilder();
|
||||
msgSb.append("in tensor shape: [");
|
||||
int[] shape = session.getInputs().get(0).getShape();
|
||||
int[] shape = model.getInputs().get(0).getShape();
|
||||
for (int dim : shape) {
|
||||
msgSb.append(dim).append(",");
|
||||
}
|
||||
|
@ -118,7 +111,7 @@ public class MainActivity extends AppCompatActivity {
|
|||
Log.i(TAG, msgSb.toString());
|
||||
}
|
||||
|
||||
return session;
|
||||
return model;
|
||||
}
|
||||
|
||||
private boolean printTensorData(MSTensor outTensor) {
|
||||
|
@ -146,9 +139,9 @@ public class MainActivity extends AppCompatActivity {
|
|||
return true;
|
||||
}
|
||||
|
||||
private boolean runInference(LiteSession session) {
|
||||
private boolean runInference(Model model) {
|
||||
Log.i(TAG, "runInference: ");
|
||||
MSTensor inputTensor = session.getInputsByTensorName("graph_input-173");
|
||||
MSTensor inputTensor = model.getInputByTensorName("graph_input-173");
|
||||
if (inputTensor.getDataType() != DataType.kNumberTypeFloat32) {
|
||||
Log.e(TAG, "Input tensor shape do not float, the data type is " + inputTensor.getDataType());
|
||||
return false;
|
||||
|
@ -160,39 +153,30 @@ public class MainActivity extends AppCompatActivity {
|
|||
|
||||
// Set Input Data.
|
||||
inputTensor.setData(inputData);
|
||||
|
||||
session.bindThread(true);
|
||||
// Run Inference.
|
||||
boolean ret = session.runGraph();
|
||||
session.bindThread(false);
|
||||
boolean ret = model.predict();
|
||||
if (!ret) {
|
||||
Log.e(TAG, "MindSpore Lite run failed.");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Get Output Tensor Data.
|
||||
MSTensor outTensor = session.getOutputByTensorName("Softmax-65");
|
||||
MSTensor outTensor = model.getOutputByTensorName("Softmax-65");
|
||||
// Print out Tensor Data.
|
||||
ret = printTensorData(outTensor);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
|
||||
outTensor = session.getOutputsByNodeName("Softmax-65").get(0);
|
||||
outTensor = model.getOutputsByNodeName("Softmax-65").get(0);
|
||||
ret = printTensorData(outTensor);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Map<String, MSTensor> outTensors = session.getOutputMapByTensor();
|
||||
|
||||
Iterator<Map.Entry<String, MSTensor>> entries = outTensors.entrySet().iterator();
|
||||
while (entries.hasNext()) {
|
||||
Map.Entry<String, MSTensor> entry = entries.next();
|
||||
|
||||
Log.i(TAG, "Tensor name is:" + entry.getKey());
|
||||
ret = printTensorData(entry.getValue());
|
||||
List<MSTensor> outTensors = model.getOutputs();
|
||||
for (MSTensor output : outTensors) {
|
||||
Log.i(TAG, "Tensor name is:" + output.tensorName());
|
||||
ret = printTensorData(output);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
|
@ -202,9 +186,8 @@ public class MainActivity extends AppCompatActivity {
|
|||
}
|
||||
|
||||
private void freeBuffer() {
|
||||
session1.free();
|
||||
session2.free();
|
||||
model.free();
|
||||
model1.free();
|
||||
model2.free();
|
||||
}
|
||||
|
||||
|
||||
|
@ -214,45 +197,33 @@ public class MainActivity extends AppCompatActivity {
|
|||
setContentView(R.layout.activity_main);
|
||||
String version = Version.version();
|
||||
Log.i(TAG, version);
|
||||
model = new Model();
|
||||
String modelPath = "mobilenetv2.ms";
|
||||
boolean ret = model.loadModel(this.getApplicationContext(), modelPath);
|
||||
if (!ret) {
|
||||
Log.e(TAG, "Load model failed, model is " + modelPath);
|
||||
model1 = createLiteModel(modelPath, false);
|
||||
if (model1 != null) {
|
||||
model1Compile = true;
|
||||
} else {
|
||||
session1 = createLiteSession(false);
|
||||
if (session1 != null) {
|
||||
session1Compile = true;
|
||||
} else {
|
||||
Toast.makeText(getApplicationContext(), "session1 Compile Failed.",
|
||||
Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
session2 = createLiteSession(true);
|
||||
if (session2 != null) {
|
||||
session2Compile = true;
|
||||
} else {
|
||||
Toast.makeText(getApplicationContext(), "session2 Compile Failed.",
|
||||
Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
Toast.makeText(getApplicationContext(), "model1 Compile Failed.",
|
||||
Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
|
||||
if (model != null) {
|
||||
// Note: when use model.freeBuffer(), the model can not be compiled again.
|
||||
model.freeBuffer();
|
||||
model2 = createLiteModel(modelPath, true);
|
||||
if (model2 != null) {
|
||||
model2Compile = true;
|
||||
} else {
|
||||
Toast.makeText(getApplicationContext(), "model2 Compile Failed.",
|
||||
Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
|
||||
TextView btn_run = findViewById(R.id.btn_run);
|
||||
btn_run.setOnClickListener(new View.OnClickListener() {
|
||||
@Override
|
||||
public void onClick(View v) {
|
||||
|
||||
if (session1Finish && session1Compile) {
|
||||
if (model1Finish && model1Compile) {
|
||||
new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
session1Finish = false;
|
||||
runInference(session1);
|
||||
session1Finish = true;
|
||||
model1Finish = false;
|
||||
runInference(model1);
|
||||
model1Finish = true;
|
||||
}
|
||||
}).start();
|
||||
} else {
|
||||
|
@ -266,27 +237,27 @@ public class MainActivity extends AppCompatActivity {
|
|||
new View.OnClickListener() {
|
||||
@Override
|
||||
public void onClick(View v) {
|
||||
if (session1Finish && session1Compile) {
|
||||
if (model1Finish && model1Compile) {
|
||||
new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
session1Finish = false;
|
||||
runInference(session1);
|
||||
session1Finish = true;
|
||||
model1Finish = false;
|
||||
runInference(model2);
|
||||
model1Finish = true;
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
if (session2Finish && session2Compile) {
|
||||
if (model2Finish && model2Compile) {
|
||||
new Thread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
session2Finish = false;
|
||||
runInference(session2);
|
||||
session2Finish = true;
|
||||
model2Finish = false;
|
||||
runInference(model2);
|
||||
model2Finish = true;
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
if (!session2Finish && !session2Finish) {
|
||||
if (!model2Finish && !model2Finish) {
|
||||
Toast.makeText(getApplicationContext(), "MindSpore Lite is running...",
|
||||
Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
|
|
|
@ -16,29 +16,28 @@
|
|||
|
||||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.TrainSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
import com.mindspore.Graph;
|
||||
import com.mindspore.Model;
|
||||
import com.mindspore.config.DeviceType;
|
||||
import com.mindspore.config.MSContext;
|
||||
import com.mindspore.config.TrainCfg;
|
||||
import com.mindspore.MSTensor;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Vector;
|
||||
|
||||
public class NetRunner {
|
||||
private int dataIndex = 0;
|
||||
private int labelIndex = 1;
|
||||
private LiteSession session;
|
||||
private Model model;
|
||||
private int batchSize;
|
||||
private long dataSize; // one input data size, in byte
|
||||
private DataSet ds = new DataSet();
|
||||
private final DataSet ds = new DataSet();
|
||||
private long numOfClasses;
|
||||
private long cycles = 2000;
|
||||
private final long cycles = 2000;
|
||||
private int idx = 1;
|
||||
private int virtualBatch = 16;
|
||||
private String trainedFilePath = "trained.ms";
|
||||
private ByteBuffer imageInputBuf;
|
||||
private ByteBuffer labelInputBuf;
|
||||
private int imageBatchElements;
|
||||
|
@ -46,37 +45,19 @@ public class NetRunner {
|
|||
private MSTensor labelTensor;
|
||||
private int[] targetLabels;
|
||||
|
||||
public void initAndFigureInputs(String modelPath, int virtualBatchSize) {
|
||||
MSConfig msConfig = new MSConfig();
|
||||
// arg 0: DeviceType:DT_CPU -> 0
|
||||
// arg 1: ThreadNum -> 2
|
||||
// arg 2: cpuBindMode:NO_BIND -> 0
|
||||
// arg 3: enable_fp16 -> false
|
||||
msConfig.init(0, 2, 0, false);
|
||||
session = new LiteSession();
|
||||
System.out.println("Model path is " + modelPath);
|
||||
session = TrainSession.createTrainSession(modelPath, msConfig, false);
|
||||
virtualBatch = virtualBatchSize;
|
||||
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
|
||||
|
||||
List<MSTensor> inputs = session.getInputs();
|
||||
private int initInputs() {
|
||||
List<MSTensor> inputs = model.getInputs();
|
||||
if (inputs.size() <= 1) {
|
||||
System.err.println("model input size: " + inputs.size());
|
||||
return;
|
||||
return -1;
|
||||
}
|
||||
|
||||
dataIndex = 0;
|
||||
labelIndex = 1;
|
||||
int dataIndex = 0;
|
||||
int labelIndex = 1;
|
||||
batchSize = inputs.get(dataIndex).getShape()[0];
|
||||
dataSize = inputs.get(dataIndex).size() / batchSize;
|
||||
System.out.println("batch_size: " + batchSize);
|
||||
System.out.println("virtual batch multiplier: " + virtualBatch);
|
||||
int index = modelPath.lastIndexOf(".ms");
|
||||
if (index == -1) {
|
||||
System.out.println("The model " + modelPath + " should be named *.ms");
|
||||
return;
|
||||
}
|
||||
trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
|
||||
|
||||
imageTensor = inputs.get(dataIndex);
|
||||
imageInputBuf = ByteBuffer.allocateDirect((int) imageTensor.size());
|
||||
|
@ -87,6 +68,46 @@ public class NetRunner {
|
|||
labelInputBuf = ByteBuffer.allocateDirect((int) labelTensor.size());
|
||||
labelInputBuf.order(ByteOrder.nativeOrder());
|
||||
targetLabels = new int[batchSize];
|
||||
return 0;
|
||||
}
|
||||
|
||||
public int initAndFigureInputs(String modelPath, int virtualBatchSize) {
|
||||
System.out.println("Model path is " + modelPath);
|
||||
MSContext context = new MSContext();
|
||||
// use default param init context
|
||||
context.init();
|
||||
boolean isSuccess = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
|
||||
if (!isSuccess) {
|
||||
System.err.println("Load graph failed");
|
||||
context.free();
|
||||
return -1;
|
||||
}
|
||||
TrainCfg trainCfg = new TrainCfg();
|
||||
isSuccess = trainCfg.init();
|
||||
if (!isSuccess) {
|
||||
System.err.println("Init train config failed");
|
||||
context.free();
|
||||
trainCfg.free();
|
||||
return -1;
|
||||
}
|
||||
model = new Model();
|
||||
Graph graph = new Graph();
|
||||
isSuccess = graph.load(modelPath);
|
||||
if (!isSuccess) {
|
||||
System.err.println("Load graph failed");
|
||||
graph.free();
|
||||
context.free();
|
||||
trainCfg.free();
|
||||
return -1;
|
||||
}
|
||||
isSuccess = model.build(graph, context, trainCfg);
|
||||
if (!isSuccess) {
|
||||
System.err.println("Build model failed");
|
||||
return -1;
|
||||
}
|
||||
virtualBatch = virtualBatchSize;
|
||||
model.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);
|
||||
return initInputs();
|
||||
}
|
||||
|
||||
public int initDB(String datasetPath) {
|
||||
|
@ -110,12 +131,16 @@ public class NetRunner {
|
|||
|
||||
public float getLoss() {
|
||||
MSTensor tensor = searchOutputsForSize(1);
|
||||
if (tensor == null) {
|
||||
System.err.println("get loss tensor failed");
|
||||
return Float.NaN;
|
||||
}
|
||||
return tensor.getFloatData()[0];
|
||||
}
|
||||
|
||||
private MSTensor searchOutputsForSize(int size) {
|
||||
Map<String, MSTensor> outputs = session.getOutputMapByTensor();
|
||||
for (MSTensor tensor : outputs.values()) {
|
||||
List<MSTensor> outputs = model.getOutputs();
|
||||
for (MSTensor tensor : outputs) {
|
||||
if (tensor.elementsNum() == size) {
|
||||
return tensor;
|
||||
}
|
||||
|
@ -125,13 +150,23 @@ public class NetRunner {
|
|||
}
|
||||
|
||||
public int trainLoop() {
|
||||
session.train();
|
||||
boolean isSuccess = model.setTrainMode(true);
|
||||
if (!isSuccess) {
|
||||
model.free();
|
||||
System.err.println("set train mode failed");
|
||||
return -1;
|
||||
}
|
||||
float min_loss = 1000;
|
||||
float max_acc = 0;
|
||||
for (int i = 0; i < cycles; i++) {
|
||||
for (int b = 0; b < virtualBatch; b++) {
|
||||
fillInputData(ds.getTrainData(), false);
|
||||
session.runGraph();
|
||||
isSuccess = model.runStep();
|
||||
if (!isSuccess) {
|
||||
model.free();
|
||||
System.err.println("run step failed");
|
||||
return -1;
|
||||
}
|
||||
float loss = getLoss();
|
||||
if (min_loss > loss) {
|
||||
min_loss = loss;
|
||||
|
@ -156,13 +191,14 @@ public class NetRunner {
|
|||
if (maxTests != -1 && tests < maxTests) {
|
||||
tests = maxTests;
|
||||
}
|
||||
session.eval();
|
||||
model.setTrainMode(false);
|
||||
for (long i = 0; i < tests; i++) {
|
||||
int[] labels = fillInputData(test_set, (maxTests == -1));
|
||||
session.runGraph();
|
||||
model.predict();
|
||||
MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
|
||||
if (outputsv == null) {
|
||||
System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
|
||||
model.free();
|
||||
System.exit(1);
|
||||
}
|
||||
float[] scores = outputsv.getFloatData();
|
||||
|
@ -181,7 +217,7 @@ public class NetRunner {
|
|||
}
|
||||
}
|
||||
}
|
||||
session.train();
|
||||
model.setTrainMode(true);
|
||||
accuracy /= (batchSize * tests);
|
||||
return accuracy;
|
||||
}
|
||||
|
@ -212,27 +248,48 @@ public class NetRunner {
|
|||
}
|
||||
|
||||
public void trainModel(String modelPath, String datasetPath, int virtualBatch) {
|
||||
int index = modelPath.lastIndexOf(".ms");
|
||||
if (index == -1) {
|
||||
System.err.println("The model " + modelPath + " should be named *.ms");
|
||||
return;
|
||||
}
|
||||
System.out.println("==========Loading Model, Create Train Session=============");
|
||||
initAndFigureInputs(modelPath, virtualBatch);
|
||||
int ret = initAndFigureInputs(modelPath, virtualBatch);
|
||||
if (ret != 0) {
|
||||
System.out.println("==========Init and figure inputs failed================");
|
||||
model.free();
|
||||
return;
|
||||
}
|
||||
System.out.println("==========Initing DataSet================");
|
||||
initDB(datasetPath);
|
||||
ret = initDB(datasetPath);
|
||||
if (ret != 0) {
|
||||
System.out.println("==========Init dataset failed================");
|
||||
return;
|
||||
}
|
||||
System.out.println("==========Training Model===================");
|
||||
trainLoop();
|
||||
ret = trainLoop();
|
||||
if (ret != 0) {
|
||||
System.out.println("==========Init dataset failed================");
|
||||
model.free();
|
||||
return;
|
||||
}
|
||||
System.out.println("==========Evaluating The Trained Model============");
|
||||
float acc = calculateAccuracy(-1);
|
||||
System.out.println("accuracy = " + acc);
|
||||
|
||||
if (cycles > 0) {
|
||||
// arg 0: FileName
|
||||
// arg 1: model type MT_TRAIN -> 0
|
||||
// arg 2: quantization type QT_DEFAULT -> 0
|
||||
if (session.export(trainedFilePath, 0, 0)) {
|
||||
// arg 1: quantization type QT_DEFAULT -> 0
|
||||
// arg 2: model type MT_TRAIN -> 0
|
||||
// arg 3: use default output tensor names
|
||||
String trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
|
||||
if (model.export(trainedFilePath, 0, false, new ArrayList<>())) {
|
||||
System.out.println("Trained model successfully saved: " + trainedFilePath);
|
||||
} else {
|
||||
System.err.println("Save model error.");
|
||||
}
|
||||
}
|
||||
session.free();
|
||||
model.free();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -219,6 +219,34 @@ public class Model {
|
|||
return new MSTensor(tensorAddr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output tensors by node name.
|
||||
*
|
||||
* @param nodeName output node name
|
||||
* @return output tensor
|
||||
*/
|
||||
public List<MSTensor> getOutputsByNodeName(String nodeName) {
|
||||
if (nodeName == null) {
|
||||
return null;
|
||||
}
|
||||
List<Long> ret = this.getOutputsByNodeName(this.modelPtr, nodeName);
|
||||
List<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
tensors.add(msTensor);
|
||||
}
|
||||
return tensors;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get output tensor names.
|
||||
*
|
||||
* @return output tensor name list.
|
||||
*/
|
||||
public List<String> getOutputTensorNames() {
|
||||
return this.getOutputTensorNames(this.modelPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Export the model.
|
||||
*
|
||||
|
@ -230,14 +258,17 @@ public class Model {
|
|||
*/
|
||||
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer,
|
||||
List<String> outputTensorNames) {
|
||||
if (fileName == null || outputTensorNames == null) {
|
||||
if (fileName == null) {
|
||||
return false;
|
||||
}
|
||||
String[] outputTensorArray = new String[outputTensorNames.size()];
|
||||
for (int i = 0; i < outputTensorNames.size(); i++) {
|
||||
outputTensorArray[i] = outputTensorNames.get(i);
|
||||
if (outputTensorNames != null) {
|
||||
String[] outputTensorArray = new String[outputTensorNames.size()];
|
||||
for (int i = 0; i < outputTensorNames.size(); i++) {
|
||||
outputTensorArray[i] = outputTensorNames.get(i);
|
||||
}
|
||||
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorArray);
|
||||
}
|
||||
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorArray);
|
||||
return export(modelPtr, fileName, quantizationType, isOnlyExportInfer, null);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -291,6 +322,28 @@ public class Model {
|
|||
return this.getTrainMode(modelPtr);
|
||||
}
|
||||
|
||||
/**
|
||||
* set learning rate.
|
||||
*
|
||||
* @param learning_rate learning rate.
|
||||
* @return Whether the set learning rate is successful.
|
||||
*/
|
||||
public boolean setLearningRate(float learning_rate) {
|
||||
return this.setLearningRate(this.modelPtr, learning_rate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the virtual batch.
|
||||
*
|
||||
* @param virtualBatchMultiplier virtual batch multuplier.
|
||||
* @param learningRate learning rate.
|
||||
* @param momentum monentum.
|
||||
* @return Whether the virtual batch is successfully set.
|
||||
*/
|
||||
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) {
|
||||
return this.setupVirtualBatch(this.modelPtr, virtualBatchMultiplier, learningRate, momentum);
|
||||
}
|
||||
|
||||
/**
|
||||
* Free model
|
||||
*/
|
||||
|
@ -317,6 +370,10 @@ public class Model {
|
|||
|
||||
private native long getOutputByTensorName(long modelPtr, String tensorName);
|
||||
|
||||
private native List<String> getOutputTensorNames(long modelPtr);
|
||||
|
||||
private native List<Long> getOutputsByNodeName(long modelPtr, String nodeName);
|
||||
|
||||
private native boolean setTrainMode(long modelPtr, boolean isTrain);
|
||||
|
||||
private native boolean getTrainMode(long modelPtr);
|
||||
|
@ -329,4 +386,9 @@ public class Model {
|
|||
private native List<Long> getFeatureMaps(long modelPtr);
|
||||
|
||||
private native boolean updateFeatureMaps(long modelPtr, long[] newFeatures);
|
||||
|
||||
private native boolean setLearningRate(long modelPtr, float learning_rate);
|
||||
|
||||
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate,
|
||||
float momentum);
|
||||
}
|
||||
|
|
|
@ -258,12 +258,60 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_getOutputByTensorNam
|
|||
return GetTensorByInOutName(env, model_ptr, tensor_name, false);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputTensorNames(JNIEnv *env, jobject thiz,
|
||||
jlong model_ptr) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
|
||||
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
return ret;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto output_names = lite_model_ptr->GetOutputTensorNames();
|
||||
for (const auto &output_name : output_names) {
|
||||
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getOutputsByNodeName(JNIEnv *env, jobject thiz,
|
||||
jlong model_ptr, jstring node_name) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
|
||||
|
||||
jclass long_object = env->FindClass("java/lang/Long");
|
||||
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
return ret;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto tensors = lite_model_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE));
|
||||
for (auto &tensor : tensors) {
|
||||
auto tensor_ptr = std::make_unique<mindspore::MSTensor>(tensor);
|
||||
if (tensor_ptr == nullptr) {
|
||||
MS_LOGE("Make ms tensor failed");
|
||||
return ret;
|
||||
}
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(tensor_ptr.release()));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_getTrainMode(JNIEnv *env, jobject thiz,
|
||||
jlong model_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return jlong(false);
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
return static_cast<jboolean>(lite_model_ptr->GetTrainMode());
|
||||
|
@ -285,7 +333,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_runStep(JNIEnv *e
|
|||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return jlong(false);
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto status = lite_model_ptr->RunStep(nullptr, nullptr);
|
||||
|
@ -313,7 +361,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_predict(JNIEnv *e
|
|||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return jlong(false);
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto c_inputs = convertArrayToVector(env, inputs);
|
||||
|
@ -328,7 +376,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
|
|||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return false;
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
|
||||
|
@ -339,7 +387,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_resize(JNIEnv *en
|
|||
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
|
||||
if (tensor_pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return false;
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto &ms_tensor = *static_cast<mindspore::MSTensor *>(tensor_pointer);
|
||||
c_inputs.push_back(ms_tensor);
|
||||
|
@ -403,7 +451,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_updateFeatureMaps
|
|||
auto *tensor_pointer = reinterpret_cast<void *>(input_data[i]);
|
||||
if (tensor_pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return false;
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::MSTensor *>(tensor_pointer);
|
||||
newFeatures.emplace_back(*ms_tensor_ptr);
|
||||
|
@ -441,6 +489,30 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_Model_getFeatureMaps(JNI
|
|||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_setLearningRate(JNIEnv *env, jclass, jlong model_ptr,
|
||||
jfloat learning_rate) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto ret = lite_model_ptr->SetLearningRate(learning_rate);
|
||||
return (jboolean)(ret.IsOk());
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_Model_setupVirtualBatch(
|
||||
JNIEnv *env, jobject thiz, jlong model_ptr, jint virtual_batch_factor, jfloat learning_rate, jfloat momentum) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Model pointer from java is nullptr");
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_model_ptr = static_cast<mindspore::Model *>(pointer);
|
||||
auto ret = lite_model_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
|
||||
return (jboolean)(ret.IsOk());
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
|
|
|
@ -34,14 +34,12 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_config_MSContext_createMSC
|
|||
return (jlong)context;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceInfo(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr, jint device_type,
|
||||
jboolean enable_fp16,
|
||||
jint npu_freq) {
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_config_MSContext_addDeviceInfo(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint device_type, jboolean enable_fp16, jint npu_freq) {
|
||||
auto *pointer = reinterpret_cast<void *>(context_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Context pointer from java is nullptr");
|
||||
return;
|
||||
return (jboolean) false;
|
||||
}
|
||||
auto *c_context_ptr = static_cast<mindspore::Context *>(pointer);
|
||||
auto &device_list = c_context_ptr->MutableDeviceInfo();
|
||||
|
@ -51,7 +49,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
|
|||
if (cpu_device_info == nullptr) {
|
||||
MS_LOGE("cpu device info is nullptr");
|
||||
delete (c_context_ptr);
|
||||
return;
|
||||
return (jboolean) false;
|
||||
}
|
||||
cpu_device_info->SetEnableFP16(enable_fp16);
|
||||
device_list.push_back(cpu_device_info);
|
||||
|
@ -63,7 +61,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
|
|||
if (gpu_device_info == nullptr) {
|
||||
MS_LOGE("gpu device info is nullptr");
|
||||
delete (c_context_ptr);
|
||||
return;
|
||||
return (jboolean) false;
|
||||
}
|
||||
gpu_device_info->SetEnableFP16(enable_fp16);
|
||||
device_list.push_back(gpu_device_info);
|
||||
|
@ -75,7 +73,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
|
|||
if (npu_device_info == nullptr) {
|
||||
MS_LOGE("npu device info is nullptr");
|
||||
delete (c_context_ptr);
|
||||
return;
|
||||
return (jboolean) false;
|
||||
}
|
||||
npu_device_info->SetFrequency(npu_freq);
|
||||
device_list.push_back(npu_device_info);
|
||||
|
@ -84,7 +82,9 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_addDeviceI
|
|||
default:
|
||||
MS_LOGE("Invalid device_type : %d", device_type);
|
||||
delete (c_context_ptr);
|
||||
return (jboolean) false;
|
||||
}
|
||||
return (jboolean) true;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_config_MSContext_free(JNIEnv *env, jobject thiz,
|
||||
|
|
|
@ -38,6 +38,10 @@ public class ModelTest {
|
|||
Model liteModel = new Model();
|
||||
boolean isSuccess = liteModel.build(g, context, cfg);
|
||||
assertTrue(isSuccess);
|
||||
isSuccess = liteModel.setLearningRate(1.0f);
|
||||
assertTrue(isSuccess);
|
||||
isSuccess = liteModel.setupVirtualBatch(2,1.0f,0.5f);
|
||||
assertTrue(isSuccess);
|
||||
liteModel.free();
|
||||
}
|
||||
|
||||
|
@ -145,9 +149,8 @@ public class ModelTest {
|
|||
for (MSTensor output : outputs) {
|
||||
System.out.println("output-------" + output.tensorName());
|
||||
}
|
||||
System.out.println("");
|
||||
MSTensor output = liteModel.getOutputByTensorName("Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense" +
|
||||
"/BiasAdd-op121");
|
||||
String outputTensorName = "Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op121";
|
||||
MSTensor output = liteModel.getOutputByTensorName(outputTensorName);
|
||||
assertEquals(80, output.size());
|
||||
output = liteModel.getOutputByTensorName("Default/network-WithLossCell/_loss_fn-L1Loss/ReduceMean-op112");
|
||||
assertEquals(0, output.size());
|
||||
|
@ -155,6 +158,13 @@ public class ModelTest {
|
|||
for (MSTensor input : inputs) {
|
||||
System.out.println(input.tensorName());
|
||||
}
|
||||
for (String name : liteModel.getOutputTensorNames()) {
|
||||
System.out.println("output tensor name:" + name);
|
||||
}
|
||||
List<MSTensor> outputTensors = liteModel.getOutputsByNodeName("Default/network-WithLossCell/_backbone-LeNet5" +
|
||||
"/fc3-Dense/MatMul-op118");
|
||||
assertEquals(1, outputTensors.size());
|
||||
assertEquals(outputTensorName, outputTensors.get(0).tensorName());
|
||||
liteModel.free();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue