forked from mindspore-Ecosystem/mindspore
add train lenet java demo
This commit is contained in:
parent
197db11fe4
commit
a967aab85e
50
build.sh
50
build.sh
|
@ -671,6 +671,12 @@ build_lite_java_arm64() {
|
|||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
else
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/
|
||||
|
@ -697,6 +703,12 @@ build_lite_java_arm32() {
|
|||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
else
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
|
@ -706,10 +718,15 @@ build_lite_java_arm32() {
|
|||
|
||||
build_lite_java_x86() {
|
||||
# build mindspore-lite x86
|
||||
local inference_or_train=inference
|
||||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
inference_or_train=train
|
||||
fi
|
||||
|
||||
if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
|
||||
local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}
|
||||
local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}
|
||||
else
|
||||
local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64
|
||||
local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
|
||||
fi
|
||||
if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then
|
||||
build_lite "x86_64" "off" ""
|
||||
|
@ -721,8 +738,20 @@ build_lite_java_x86() {
|
|||
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/
|
||||
mkdir -p ${JAVA_PATH}/java/linux_x86/libs/
|
||||
mkdir -p ${JAVA_PATH}/native/libs/linux_x86/
|
||||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/
|
||||
else
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/
|
||||
fi
|
||||
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
|
||||
}
|
||||
|
||||
build_jni_arm64() {
|
||||
|
@ -776,7 +805,7 @@ build_jni_x86_64() {
|
|||
mkdir -pv java/jni
|
||||
cd java/jni
|
||||
cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
|
||||
-DENABLE_VERBOSE=${ENABLE_VERBOSE} "${JAVA_PATH}/native/"
|
||||
-DENABLE_VERBOSE=${ENABLE_VERBOSE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} "${JAVA_PATH}/native/"
|
||||
make -j$THREAD_NUM
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "---------------- mindspore lite: build jni x86_64 failed----------------"
|
||||
|
@ -825,11 +854,16 @@ build_java() {
|
|||
cd ${JAVA_PATH}/java/app/build
|
||||
zip -r mindspore-lite-maven-${VERSION_STR}.zip mindspore
|
||||
|
||||
local inference_or_train=inference
|
||||
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
|
||||
inference_or_train=train
|
||||
fi
|
||||
|
||||
# build linux x86 jar
|
||||
if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then
|
||||
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}-jar
|
||||
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}-jar
|
||||
else
|
||||
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-jar
|
||||
local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-jar
|
||||
fi
|
||||
check_java_home
|
||||
build_lite_java_x86
|
||||
|
@ -843,15 +877,17 @@ build_java() {
|
|||
gradle releaseJar
|
||||
# install and package
|
||||
mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib
|
||||
cp ${JAVA_PATH}/java/linux_x86/libs/*.so ${JAVA_PATH}/java/linux_x86/build/lib/jar
|
||||
cp ${JAVA_PATH}/java/linux_x86/libs/*.so* ${JAVA_PATH}/java/linux_x86/build/lib/jar
|
||||
cd ${JAVA_PATH}/java/linux_x86/build/
|
||||
|
||||
cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME}
|
||||
tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME}
|
||||
# copy output
|
||||
cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output
|
||||
cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output
|
||||
|
||||
cd ${BASEPATH}/output
|
||||
[ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-inference-linux-x64
|
||||
[ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64
|
||||
exit 0
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>com.mindspore.lite.demo</groupId>
|
||||
<artifactId>train_lenet_java</artifactId>
|
||||
<version>1.0</version>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.mindspore.lite</groupId>
|
||||
<artifactId>mindspore-lite-java</artifactId>
|
||||
<version>1.0</version>
|
||||
<scope>system</scope>
|
||||
<systemPath>${project.basedir}/lib/mindspore-lite-java.jar</systemPath>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<finalName>${project.name}</finalName>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-assembly-plugin</artifactId>
|
||||
<configuration>
|
||||
<archive>
|
||||
<manifest>
|
||||
<mainClass>com.mindspore.lite.train_lenet.Main</mainClass>
|
||||
</manifest>
|
||||
</archive>
|
||||
<descriptorRefs>
|
||||
<descriptorRef>jar-with-dependencies</descriptorRef>
|
||||
</descriptorRefs>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>make-assemble</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>single</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
Binary file not shown.
|
@ -0,0 +1,131 @@
|
|||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.Vector;
|
||||
|
||||
public class DataSet {
|
||||
private long numOfClasses = 0;
|
||||
private long expectedDataSize = 0;
|
||||
public class DataLabelTuple {
|
||||
public float[] data;
|
||||
public int label;
|
||||
}
|
||||
Vector<DataLabelTuple> trainData;
|
||||
Vector<DataLabelTuple> testData;
|
||||
|
||||
public void initializeMNISTDatabase(String dpath) {
|
||||
numOfClasses = 10;
|
||||
trainData = new Vector<DataLabelTuple>();
|
||||
testData = new Vector<DataLabelTuple>();
|
||||
readMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath+"/train/train-labels-idx1-ubyte", trainData);
|
||||
readMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath+"/test/t10k-labels-idx1-ubyte", testData);
|
||||
|
||||
System.out.println("train data cnt: " + trainData.size());
|
||||
System.out.println("test data cnt: " + testData.size());
|
||||
}
|
||||
|
||||
private String bytesToHex(byte[] bytes) {
|
||||
StringBuffer sb = new StringBuffer();
|
||||
for (int i = 0; i < bytes.length; i++) {
|
||||
String hex = Integer.toHexString(bytes[i] & 0xFF);
|
||||
if (hex.length() < 2) {
|
||||
sb.append(0);
|
||||
}
|
||||
sb.append(hex);
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private void readFile(BufferedInputStream inputStream, byte[] bytes, int len) throws IOException {
|
||||
int result = inputStream.read(bytes, 0, len);
|
||||
if (result != len) {
|
||||
System.err.println("expected read " + len + " bytes, but " + result + " read");
|
||||
System.exit(1);
|
||||
}
|
||||
}
|
||||
public void readMNISTFile(String inputFileName, String labelFileName, Vector<DataLabelTuple> dataset) {
|
||||
try {
|
||||
BufferedInputStream ibin = new BufferedInputStream(new FileInputStream(inputFileName));
|
||||
BufferedInputStream lbin = new BufferedInputStream(new FileInputStream(labelFileName));
|
||||
byte[] bytes = new byte[4];
|
||||
|
||||
readFile(ibin, bytes, 4);
|
||||
if (!"00000803".equals(bytesToHex(bytes))) { // 2051
|
||||
System.err.println("The dataset is not valid: " + bytesToHex(bytes));
|
||||
return;
|
||||
}
|
||||
readFile(ibin, bytes, 4);
|
||||
int inumber = Integer.parseInt(bytesToHex(bytes), 16);
|
||||
|
||||
readFile(lbin, bytes, 4);
|
||||
if (!"00000801".equals(bytesToHex(bytes))) { // 2049
|
||||
System.err.println("The dataset label is not valid: " + bytesToHex(bytes));
|
||||
return;
|
||||
}
|
||||
readFile(lbin, bytes, 4);
|
||||
int lnumber = Integer.parseInt(bytesToHex(bytes), 16);
|
||||
if (inumber != lnumber) {
|
||||
System.err.println("input data cnt: " + inumber + " not equal label cnt: " + lnumber);
|
||||
return;
|
||||
}
|
||||
|
||||
// read all labels
|
||||
byte[] labels = new byte[lnumber];
|
||||
readFile(lbin, labels, lnumber);
|
||||
|
||||
// row, column
|
||||
readFile(ibin, bytes, 4);
|
||||
int n_rows = Integer.parseInt(bytesToHex(bytes), 16);
|
||||
readFile(ibin, bytes, 4);
|
||||
int n_cols = Integer.parseInt(bytesToHex(bytes), 16);
|
||||
if (n_rows != 28 || n_cols != 28) {
|
||||
System.err.println("invalid n_rows: " + n_rows + " n_cols: " + n_cols);
|
||||
return;
|
||||
}
|
||||
// read images
|
||||
int image_size = n_rows * n_cols;
|
||||
byte[] image_data = new byte[image_size];
|
||||
for (int i = 0; i < lnumber; i++) {
|
||||
float [] hwc_bin_image = new float[32 * 32];
|
||||
readFile(ibin, image_data, image_size);
|
||||
for (int r = 0; r < 32; r++) {
|
||||
for (int c = 0; c < 32; c++) {
|
||||
int index = r * 32 + c;
|
||||
if (r < 2 || r > 29 || c < 2 || c > 29) {
|
||||
hwc_bin_image[index] = 0;
|
||||
} else {
|
||||
int data = image_data[(r-2)*28 + (c-2)] & 0xff;
|
||||
hwc_bin_image[index] = (float)data / 255.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DataLabelTuple data_label_tupel = new DataLabelTuple();
|
||||
data_label_tupel.data = hwc_bin_image;
|
||||
data_label_tupel.label = labels[i] & 0xff;
|
||||
dataset.add(data_label_tupel);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
System.err.println("Read Dateset exception");
|
||||
}
|
||||
}
|
||||
|
||||
public void setExpectedDataSize(long data_size) {
|
||||
expectedDataSize = data_size;
|
||||
}
|
||||
|
||||
public long getNumOfClasses() {
|
||||
return numOfClasses;
|
||||
}
|
||||
|
||||
public Vector<DataLabelTuple> getTestData() {
|
||||
return testData;
|
||||
}
|
||||
|
||||
public Vector<DataLabelTuple> getTrainData() {
|
||||
return trainData;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
import com.mindspore.lite.Version;
|
||||
|
||||
public class Main {
|
||||
public static void main(String[] args) {
|
||||
System.out.println(Version.version());
|
||||
if (args.length < 2) {
|
||||
System.err.println("model path and dataset path must be provided.");
|
||||
return;
|
||||
}
|
||||
String modelPath = args[0];
|
||||
String datasetPath = args[1];
|
||||
|
||||
NetRunner net_runner = new NetRunner();
|
||||
net_runner.trainModel(modelPath, datasetPath);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
package com.mindspore.lite.train_lenet;
|
||||
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.TrainSession;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Vector;
|
||||
|
||||
public class NetRunner {
|
||||
private int dataIndex = 0;
|
||||
private int labelIndex = 1;
|
||||
private TrainSession session;
|
||||
private long batchSize;
|
||||
private long dataSize; // one input data size, in byte
|
||||
private DataSet ds = new DataSet();
|
||||
private long numOfClasses;
|
||||
private long cycles = 3500;
|
||||
private int idx = 1;
|
||||
private String trainedFilePath = "trained.ms";
|
||||
|
||||
public void initAndFigureInputs(String modelPath) {
|
||||
MSConfig msConfig = new MSConfig();
|
||||
// arg 0: DeviceType:DT_CPU -> 0
|
||||
// arg 1: ThreadNum -> 2
|
||||
// arg 2: cpuBindMode:NO_BIND -> 0
|
||||
// arg 3: enable_fp16 -> false
|
||||
msConfig.init(0, 2, 0, false);
|
||||
session = new TrainSession();
|
||||
session.init(modelPath, msConfig);
|
||||
session.setLearningRate(0.01f);
|
||||
|
||||
List<MSTensor> inputs = session.getInputs();
|
||||
if (inputs.size() <= 1) {
|
||||
System.err.println("model input size: " + inputs.size());
|
||||
return;
|
||||
}
|
||||
|
||||
dataIndex = 0;
|
||||
labelIndex = 1;
|
||||
batchSize = inputs.get(dataIndex).getShape()[0];
|
||||
dataSize = inputs.get(dataIndex).size() / batchSize;
|
||||
System.out.println("batch_size: " + batchSize);
|
||||
int index = modelPath.lastIndexOf(".ms");
|
||||
if (index == -1) {
|
||||
System.out.println("The model " + modelPath + " should be named *.ms");
|
||||
return;
|
||||
}
|
||||
trainedFilePath = modelPath.substring(0, index) + "_trained.ms";
|
||||
}
|
||||
|
||||
public int initDB(String datasetPath) {
|
||||
if (dataSize != 0) {
|
||||
ds.setExpectedDataSize(dataSize);
|
||||
}
|
||||
ds.initializeMNISTDatabase(datasetPath);
|
||||
numOfClasses = ds.getNumOfClasses();
|
||||
if (numOfClasses != 10) {
|
||||
System.err.println("unexpected num_of_class: " + numOfClasses);
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
if (ds.testData.size() == 0) {
|
||||
System.err.println("test data size is 0");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
public float getLoss() {
|
||||
MSTensor tensor = searchOutputsForSize(1);
|
||||
return tensor.getFloatData()[0];
|
||||
}
|
||||
|
||||
private MSTensor searchOutputsForSize(int size) {
|
||||
Map<String, MSTensor> outputs = session.getOutputMapByTensor();
|
||||
for (MSTensor tensor : outputs.values()) {
|
||||
if (tensor.elementsNum() == size) {
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||
System.err.println("can not find output the tensor which element num is " + size);
|
||||
return null;
|
||||
}
|
||||
|
||||
public int trainLoop() {
|
||||
session.train();
|
||||
float min_loss = 1000;
|
||||
float max_acc = 0;
|
||||
for (int i = 0; i < cycles; i++) {
|
||||
fillInputData(ds.getTrainData(), false);
|
||||
session.runGraph();
|
||||
float loss = getLoss();
|
||||
if (min_loss > loss) {
|
||||
min_loss = loss;
|
||||
}
|
||||
if ((i + 1) % 500 == 0) {
|
||||
float acc = calculateAccuracy(10); // only test 10 batch size
|
||||
if (max_acc < acc) {
|
||||
max_acc = acc;
|
||||
}
|
||||
System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc);
|
||||
}
|
||||
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public float calculateAccuracy(long maxTests) {
|
||||
float accuracy = 0;
|
||||
Vector<DataSet.DataLabelTuple> test_set = ds.getTestData();
|
||||
long tests = test_set.size() / batchSize;
|
||||
if (maxTests != -1 && tests < maxTests) {
|
||||
tests = maxTests;
|
||||
}
|
||||
session.eval();
|
||||
for (long i = 0; i < tests; i++) {
|
||||
Vector<Integer> labels = fillInputData(test_set, (maxTests == -1));
|
||||
if (labels.size() != batchSize) {
|
||||
System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize);
|
||||
System.exit(1);
|
||||
}
|
||||
session.runGraph();
|
||||
MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses));
|
||||
if (outputsv == null) {
|
||||
System.err.println("can not find output tensor with size: " + batchSize * numOfClasses);
|
||||
System.exit(1);
|
||||
}
|
||||
float[] scores = outputsv.getFloatData();
|
||||
for (int b = 0; b < batchSize; b++) {
|
||||
int max_idx = 0;
|
||||
float max_score = scores[(int) (numOfClasses * b)];
|
||||
for (int c = 0; c < numOfClasses; c++) {
|
||||
if (scores[(int) (numOfClasses * b + c)] > max_score) {
|
||||
max_score = scores[(int) (numOfClasses * b + c)];
|
||||
max_idx = c;
|
||||
}
|
||||
|
||||
}
|
||||
if (labels.get(b) == max_idx) {
|
||||
accuracy += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
session.train();
|
||||
accuracy /= (batchSize * tests);
|
||||
return accuracy;
|
||||
}
|
||||
|
||||
// each time fill batch_size data
|
||||
Vector<Integer> fillInputData(Vector<DataSet.DataLabelTuple> dataset, boolean serially) {
|
||||
Vector<Integer> labelsVec = new Vector<Integer>();
|
||||
int totalSize = dataset.size();
|
||||
|
||||
List<MSTensor> inputs = session.getInputs();
|
||||
|
||||
int inputDataCnt = inputs.get(dataIndex).elementsNum();
|
||||
float[] inputBatchData = new float[inputDataCnt];
|
||||
|
||||
int labelDataCnt = inputs.get(labelIndex).elementsNum();
|
||||
int[] labelBatchData = new int[labelDataCnt];
|
||||
|
||||
for (int i = 0; i < batchSize; i++) {
|
||||
if (serially) {
|
||||
idx = (++idx) % totalSize;
|
||||
} else {
|
||||
idx = (int) (Math.random() * totalSize);
|
||||
}
|
||||
|
||||
int label = 0;
|
||||
DataSet.DataLabelTuple dataLabelTuple = dataset.get(idx);
|
||||
label = dataLabelTuple.label;
|
||||
System.arraycopy(dataLabelTuple.data, 0, inputBatchData, (int) (i * dataLabelTuple.data.length), dataLabelTuple.data.length);
|
||||
labelBatchData[i] = label;
|
||||
labelsVec.add(label);
|
||||
}
|
||||
|
||||
ByteBuffer byteBuf = ByteBuffer.allocateDirect(inputBatchData.length * Float.BYTES);
|
||||
byteBuf.order(ByteOrder.nativeOrder());
|
||||
for (int i = 0; i < inputBatchData.length; i++) {
|
||||
byteBuf.putFloat(inputBatchData[i]);
|
||||
}
|
||||
inputs.get(dataIndex).setData(byteBuf);
|
||||
|
||||
ByteBuffer labelByteBuf = ByteBuffer.allocateDirect(labelBatchData.length * 4);
|
||||
labelByteBuf.order(ByteOrder.nativeOrder());
|
||||
for (int i = 0; i < labelBatchData.length; i++) {
|
||||
labelByteBuf.putInt(labelBatchData[i]);
|
||||
}
|
||||
inputs.get(labelIndex).setData(labelByteBuf);
|
||||
|
||||
return labelsVec;
|
||||
}
|
||||
|
||||
public void trainModel(String modelPath, String datasetPath) {
|
||||
System.out.println("==========Loading Model, Create Train Session=============");
|
||||
initAndFigureInputs(modelPath);
|
||||
System.out.println("==========Initing DataSet================");
|
||||
initDB(datasetPath);
|
||||
System.out.println("==========Training Model===================");
|
||||
trainLoop();
|
||||
System.out.println("==========Evaluating The Trained Model============");
|
||||
float acc = calculateAccuracy(-1);
|
||||
System.out.println("accuracy = " + acc);
|
||||
|
||||
if (cycles > 0) {
|
||||
if (session.saveToFile(trainedFilePath)) {
|
||||
System.out.println("Trained model successfully saved: " + trainedFilePath);
|
||||
} else {
|
||||
System.err.println("Save model error.");
|
||||
}
|
||||
}
|
||||
session.free();
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue