add train lenet java demo

This commit is contained in:
xutianchun 2021-03-24 09:17:51 +08:00
parent 197db11fe4
commit a967aab85e
6 changed files with 471 additions and 9 deletions

View File

@ -671,6 +671,12 @@ build_lite_java_arm64() {
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/arm64-v8a/
else
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
@ -697,6 +703,12 @@ build_lite_java_arm32() {
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/armeabi-v7a/
else
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
@ -706,10 +718,15 @@ build_lite_java_arm32() {
build_lite_java_x86() {
# build mindspore-lite x86
local inference_or_train=inference
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
inference_or_train=train
fi
if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}
local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}
else
local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64
local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
fi
if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then
build_lite "x86_64" "off" ""
@ -721,8 +738,20 @@ build_lite_java_x86() {
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/
mkdir -p ${JAVA_PATH}/java/linux_x86/libs/
mkdir -p ${JAVA_PATH}/native/libs/linux_x86/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/
else
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
fi
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
}
build_jni_arm64() {
@ -776,7 +805,7 @@ build_jni_x86_64() {
mkdir -pv java/jni
cd java/jni
cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
-DENABLE_VERBOSE=${ENABLE_VERBOSE} "${JAVA_PATH}/native/"
-DENABLE_VERBOSE=${ENABLE_VERBOSE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} "${JAVA_PATH}/native/"
make -j$THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "---------------- mindspore lite: build jni x86_64 failed----------------"
@ -825,11 +854,16 @@ build_java() {
cd ${JAVA_PATH}/java/app/build
zip -r mindspore-lite-maven-${VERSION_STR}.zip mindspore
local inference_or_train=inference
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
inference_or_train=train
fi
# build linux x86 jar
if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}-jar
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}-jar
else
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-jar
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-jar
fi
check_java_home
build_lite_java_x86
@ -843,15 +877,17 @@ build_java() {
gradle releaseJar
# install and package
mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib
cp ${JAVA_PATH}/java/linux_x86/libs/*.so ${JAVA_PATH}/java/linux_x86/build/lib/jar
cp ${JAVA_PATH}/java/linux_x86/libs/*.so* ${JAVA_PATH}/java/linux_x86/build/lib/jar
cd ${JAVA_PATH}/java/linux_x86/build/
cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME}
tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME}
# copy output
cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output
cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output
cd ${BASEPATH}/output
[ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-inference-linux-x64
[ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
exit 0
}

View File

@ -0,0 +1,55 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.mindspore.lite.demo</groupId>
<artifactId>train_lenet_java</artifactId>
<version>1.0</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>com.mindspore.lite</groupId>
<artifactId>mindspore-lite-java</artifactId>
<version>1.0</version>
<scope>system</scope>
<systemPath>${project.basedir}/lib/mindspore-lite-java.jar</systemPath>
</dependency>
</dependencies>
<build>
<finalName>${project.name}</finalName>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<archive>
<manifest>
<mainClass>com.mindspore.lite.train_lenet.Main</mainClass>
</manifest>
</archive>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assemble</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,131 @@
package com.mindspore.lite.train_lenet;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Vector;
public class DataSet {
private long numOfClasses = 0;
private long expectedDataSize = 0;
public class DataLabelTuple {
public float[] data;
public int label;
}
Vector<DataLabelTuple> trainData;
Vector<DataLabelTuple> testData;
public void initializeMNISTDatabase(String dpath) {
numOfClasses = 10;
trainData = new Vector<DataLabelTuple>();
testData = new Vector<DataLabelTuple>();
readMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath+"/train/train-labels-idx1-ubyte", trainData);
readMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath+"/test/t10k-labels-idx1-ubyte", testData);
System.out.println("train data cnt: " + trainData.size());
System.out.println("test data cnt: " + testData.size());
}
private String bytesToHex(byte[] bytes) {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < bytes.length; i++) {
String hex = Integer.toHexString(bytes[i] & 0xFF);
if (hex.length() < 2) {
sb.append(0);
}
sb.append(hex);
}
return sb.toString();
}
private void readFile(BufferedInputStream inputStream, byte[] bytes, int len) throws IOException {
int result = inputStream.read(bytes, 0, len);
if (result != len) {
System.err.println("expected read " + len + " bytes, but " + result + " read");
System.exit(1);
}
}
public void readMNISTFile(String inputFileName, String labelFileName, Vector<DataLabelTuple> dataset) {
try {
BufferedInputStream ibin = new BufferedInputStream(new FileInputStream(inputFileName));
BufferedInputStream lbin = new BufferedInputStream(new FileInputStream(labelFileName));
byte[] bytes = new byte[4];
readFile(ibin, bytes, 4);
if (!"00000803".equals(bytesToHex(bytes))) { // 2051
System.err.println("The dataset is not valid: " + bytesToHex(bytes));
return;
}
readFile(ibin, bytes, 4);
int inumber = Integer.parseInt(bytesToHex(bytes), 16);
readFile(lbin, bytes, 4);
if (!"00000801".equals(bytesToHex(bytes))) { // 2049
System.err.println("The dataset label is not valid: " + bytesToHex(bytes));
return;
}
readFile(lbin, bytes, 4);
int lnumber = Integer.parseInt(bytesToHex(bytes), 16);
if (inumber != lnumber) {
System.err.println("input data cnt: " + inumber + " not equal label cnt: " + lnumber);
return;
}
// read all labels
byte[] labels = new byte[lnumber];
readFile(lbin, labels, lnumber);
// row, column
readFile(ibin, bytes, 4);
int n_rows = Integer.parseInt(bytesToHex(bytes), 16);
readFile(ibin, bytes, 4);
int n_cols = Integer.parseInt(bytesToHex(bytes), 16);
if (n_rows != 28 || n_cols != 28) {
System.err.println("invalid n_rows: " + n_rows + " n_cols: " + n_cols);
return;
}
// read images
int image_size = n_rows * n_cols;
byte[] image_data = new byte[image_size];
for (int i = 0; i < lnumber; i++) {
float [] hwc_bin_image = new float[32 * 32];
readFile(ibin, image_data, image_size);
for (int r = 0; r < 32; r++) {
for (int c = 0; c < 32; c++) {
int index = r * 32 + c;
if (r < 2 || r > 29 || c < 2 || c > 29) {
hwc_bin_image[index] = 0;
} else {
int data = image_data[(r-2)*28 + (c-2)] & 0xff;
hwc_bin_image[index] = (float)data / 255.0f;
}
}
}
DataLabelTuple data_label_tupel = new DataLabelTuple();
data_label_tupel.data = hwc_bin_image;
data_label_tupel.label = labels[i] & 0xff;
dataset.add(data_label_tupel);
}
} catch (IOException e) {
System.err.println("Read Dateset exception");
}
}
public void setExpectedDataSize(long data_size) {
expectedDataSize = data_size;
}
public long getNumOfClasses() {
return numOfClasses;
}
public Vector<DataLabelTuple> getTestData() {
return testData;
}
public Vector<DataLabelTuple> getTrainData() {
return trainData;
}
}

View File

@ -0,0 +1,20 @@
package com.mindspore.lite.train_lenet;
import com.mindspore.lite.Version;
public class Main {
public static void main(String[] args) {
System.out.println(Version.version());
if (args.length < 2) {
System.err.println("model path and dataset path must be provided.");
return;
}
String modelPath = args[0];
String datasetPath = args[1];
NetRunner net_runner = new NetRunner();
net_runner.trainModel(modelPath, datasetPath);
}
}

View File

@ -0,0 +1,220 @@
package com.mindspore.lite.train_lenet;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;
import java.util.Map;
import java.util.Vector;
public class NetRunner {
private int dataIndex = 0;
private int labelIndex = 1;
private TrainSession session;
private long batchSize;
private long dataSize; // one input data size, in byte
private DataSet ds = new DataSet();
private long numOfClasses;
private long cycles = 3500;
private int idx = 1;
private String trainedFilePath = "trained.ms";
public void initAndFigureInputs(String modelPath) {
MSConfig msConfig = new MSConfig();
// arg 0: DeviceType:DT_CPU -> 0
// arg 1: ThreadNum -> 2
// arg 2: cpuBindMode:NO_BIND -> 0
// arg 3: enable_fp16 -> false
msConfig.init(0, 2, 0, false);
session = new TrainSession();
session.init(modelPath, msConfig);
session.setLearningRate(0.01f);
List<MSTensor> inputs = session.getInputs();
if (inputs.size() <= 1) {
System.err.println("model input size: " + inputs.size());
return;
}
dataIndex = 0;
labelIndex = 1;
batchSize = inputs.get(dataIndex).getShape()[0];
dataSize = inputs.get(dataIndex).size() / batchSize;
System.out.println("batch_size: " + batchSize);
int index = modelPath.lastIndexOf(".ms");
if (index == -1) {
System.out.println("The model " + modelPath + " should be named *.ms");
return;
}
trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
}
public int initDB(String datasetPath) {
if (dataSize != 0) {
ds.setExpectedDataSize(dataSize);
}
ds.initializeMNISTDatabase(datasetPath);
numOfClasses = ds.getNumOfClasses();
if (numOfClasses != 10) {
System.err.println("unexpected num_of_class: " + numOfClasses);
System.exit(1);
}
if (ds.testData.size() == 0) {
System.err.println("test data size is 0");
return -1;
}
return 0;
}
public float getLoss() {
MSTensor tensor = searchOutputsForSize(1);
return tensor.getFloatData()[0];
}
private MSTensor searchOutputsForSize(int size) {
Map<String, MSTensor> outputs = session.getOutputMapByTensor();
for (MSTensor tensor : outputs.values()) {
if (tensor.elementsNum() == size) {
return tensor;
}
}
System.err.println("can not find output the tensor which element num is " + size);
return null;
}
public int trainLoop() {
session.train();
float min_loss = 1000;
float max_acc = 0;
for (int i = 0; i < cycles; i++) {
fillInputData(ds.getTrainData(), false);
session.runGraph();
float loss = getLoss();
if (min_loss > loss) {
min_loss = loss;
}
if ((i + 1) % 500 == 0) {
float acc = calculateAccuracy(10); // only test 10 batch size
if (max_acc < acc) {
max_acc = acc;
}
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
}
}
return 0;
}
public float calculateAccuracy(long maxTests) {
float accuracy = 0;
Vector<DataSet.DataLabelTuple> test_set = ds.getTestData();
long tests = test_set.size() / batchSize;
if (maxTests != -1 && tests < maxTests) {
tests = maxTests;
}
session.eval();
for (long i = 0; i < tests; i++) {
Vector<Integer> labels = fillInputData(test_set, (maxTests == -1));
if (labels.size() != batchSize) {
System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize);
System.exit(1);
}
session.runGraph();
MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
if (outputsv == null) {
System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
System.exit(1);
}
float[] scores = outputsv.getFloatData();
for (int b = 0; b < batchSize; b++) {
int max_idx = 0;
float max_score = scores[(int) (numOfClasses * b)];
for (int c = 0; c < numOfClasses; c++) {
if (scores[(int) (numOfClasses * b + c)] > max_score) {
max_score = scores[(int) (numOfClasses * b + c)];
max_idx = c;
}
}
if (labels.get(b) == max_idx) {
accuracy += 1.0;
}
}
}
session.train();
accuracy /= (batchSize * tests);
return accuracy;
}
// each time fill batch_size data
Vector<Integer> fillInputData(Vector<DataSet.DataLabelTuple> dataset, boolean serially) {
Vector<Integer> labelsVec = new Vector<Integer>();
int totalSize = dataset.size();
List<MSTensor> inputs = session.getInputs();
int inputDataCnt = inputs.get(dataIndex).elementsNum();
float[] inputBatchData = new float[inputDataCnt];
int labelDataCnt = inputs.get(labelIndex).elementsNum();
int[] labelBatchData = new int[labelDataCnt];
for (int i = 0; i < batchSize; i++) {
if (serially) {
idx = (++idx) % totalSize;
} else {
idx = (int) (Math.random() * totalSize);
}
int label = 0;
DataSet.DataLabelTuple dataLabelTuple = dataset.get(idx);
label = dataLabelTuple.label;
System.arraycopy(dataLabelTuple.data, 0, inputBatchData, (int) (i * dataLabelTuple.data.length), dataLabelTuple.data.length);
labelBatchData[i] = label;
labelsVec.add(label);
}
ByteBuffer byteBuf = ByteBuffer.allocateDirect(inputBatchData.length * Float.BYTES);
byteBuf.order(ByteOrder.nativeOrder());
for (int i = 0; i < inputBatchData.length; i++) {
byteBuf.putFloat(inputBatchData[i]);
}
inputs.get(dataIndex).setData(byteBuf);
ByteBuffer labelByteBuf = ByteBuffer.allocateDirect(labelBatchData.length * 4);
labelByteBuf.order(ByteOrder.nativeOrder());
for (int i = 0; i < labelBatchData.length; i++) {
labelByteBuf.putInt(labelBatchData[i]);
}
inputs.get(labelIndex).setData(labelByteBuf);
return labelsVec;
}
public void trainModel(String modelPath, String datasetPath) {
System.out.println("==========Loading Model, Create Train Session=============");
initAndFigureInputs(modelPath);
System.out.println("==========Initing DataSet================");
initDB(datasetPath);
System.out.println("==========Training Model===================");
trainLoop();
System.out.println("==========Evaluating The Trained Model============");
float acc = calculateAccuracy(-1);
System.out.println("accuracy = " + acc);
if (cycles > 0) {
if (session.saveToFile(trainedFilePath)) {
System.out.println("Trained model successfully saved: " + trainedFilePath);
} else {
System.err.println("Save model error.");
}
}
session.free();
}
}