modify load native lib and Benchmark switch new api
This commit is contained in:
parent
efe40a1ffc
commit
1fa35a3b0a
|
@ -1,5 +1,5 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -82,15 +82,24 @@ build_lite_x86_64_jni_and_jar() {
|
|||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-jni.so ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/
|
||||
if [ -f "${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/glog/libglog.so.0" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/*.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/
|
||||
LIB_GLOG="libglog.so*"
|
||||
if [ -f "`echo ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/glog/${LIB_GLOG}`" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/glog/libglog.so* ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/libglog.so
|
||||
fi
|
||||
LIB_JPEG="libjpeg.so*"
|
||||
if [ -f "`echo ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_JPEG}`" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_JPEG} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/
|
||||
fi
|
||||
LIB_TURBOJPEG="libturbojpeg.so*"
|
||||
if [ -f "`echo ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_TURBOJPEG}`" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_TURBOJPEG} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/
|
||||
fi
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-train-jni.so ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_x86_64/
|
||||
fi
|
||||
|
||||
cd ${LITE_JAVA_PATH}/java
|
||||
|
@ -181,14 +190,25 @@ build_lite_aarch64_jni_and_jar() {
|
|||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_aarch64/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/
|
||||
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/
|
||||
if [ -f "${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/glog/libglog.so.0" ]; then
|
||||
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/*.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/
|
||||
LIB_GLOG="libglog.so*"
|
||||
if [ -f "`echo ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/glog/${LIB_GLOG}`" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/glog/libglog.so* ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/libglog.so
|
||||
fi
|
||||
LIB_JPEG="libjpeg.so*"
|
||||
if [ -f "`echo ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_JPEG}`" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_JPEG} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/
|
||||
fi
|
||||
LIB_TURBOJPEG="libturbojpeg.so*"
|
||||
if [ -f "`echo ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_TURBOJPEG}`" ]; then
|
||||
cp ${BASEPATH}/output/tmp/${pkg_name}/runtime/third_party/libjpeg-turbo/lib/${LIB_TURBOJPEG} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/
|
||||
fi
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_aarch64/libs/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_aarch64/
|
||||
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/linux_aarch64/
|
||||
fi
|
||||
|
||||
cd ${LITE_JAVA_PATH}/java
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,17 +16,11 @@
|
|||
|
||||
package com.mindspore;
|
||||
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
import com.mindspore.config.MindsporeLite;
|
||||
|
||||
public class Graph {
|
||||
static {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long graphPtr;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,19 +16,13 @@
|
|||
|
||||
package com.mindspore;
|
||||
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
import com.mindspore.config.MindsporeLite;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
public class MSTensor {
|
||||
static {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long tensorPtr;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,8 +17,8 @@
|
|||
package com.mindspore;
|
||||
|
||||
import com.mindspore.config.MSContext;
|
||||
import com.mindspore.config.MindsporeLite;
|
||||
import com.mindspore.config.TrainCfg;
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
|
@ -26,13 +26,7 @@ import java.util.List;
|
|||
|
||||
public class Model {
|
||||
static {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long modelPtr = 0;
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
package com.mindspore;
|
||||
|
||||
import com.mindspore.config.MindsporeLite;
|
||||
import com.mindspore.config.RunnerConfig;
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -29,13 +29,7 @@ import java.util.List;
|
|||
*/
|
||||
public class ModelParallelRunner {
|
||||
static {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long modelParallelRunnerPtr;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,17 +16,9 @@
|
|||
|
||||
package com.mindspore.config;
|
||||
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
|
||||
public class MSContext {
|
||||
static {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long msContextPtr;
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package com.mindspore.config;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public final class MindsporeLite {
|
||||
private static final Logger LOGGER = Logger.getLogger(MindsporeLite.class.toString());
|
||||
|
||||
/**
|
||||
* Init function.
|
||||
*/
|
||||
public static void init() {
|
||||
LOGGER.info("MindsporeLite init load ...");
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("Failed to load MindSporLite native library.");
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
static {
|
||||
LOGGER.info("MindsporeLite init ...");
|
||||
init();
|
||||
}
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
package com.mindspore.lite;
|
||||
package com.mindspore.config;
|
||||
|
||||
import com.mindspore.config.Version;
|
||||
import java.io.*;
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class NativeLibrary {
|
||||
|
@ -11,6 +13,9 @@ public class NativeLibrary {
|
|||
private static final String MINDSPORE_LITE_LIBNAME = "mindspore-lite";
|
||||
private static final String MINDSPORE_LITE_JNI_LIBNAME = "mindspore-lite-jni";
|
||||
|
||||
/**
|
||||
* Load function.
|
||||
*/
|
||||
public static void load() {
|
||||
if (isLibLoaded() || loadLibrary()) {
|
||||
LOGGER.info("Native lib has been loaded.");
|
||||
|
@ -19,6 +24,9 @@ public class NativeLibrary {
|
|||
loadLibs();
|
||||
}
|
||||
|
||||
/**
|
||||
* Load native libs function.
|
||||
*/
|
||||
public static void loadLibs() {
|
||||
loadLib(makeResourceName("lib" + GLOG_LIBNAME + ".so"));
|
||||
loadLib(makeResourceName("lib" + MINDSPORE_LITE_LIBNAME + ".so"));
|
||||
|
@ -34,28 +42,31 @@ public class NativeLibrary {
|
|||
return true;
|
||||
}
|
||||
|
||||
public static boolean loadLibrary() {
|
||||
/**
|
||||
* Load library function.
|
||||
* If any jni lib is loaded successfully, the function return True.
|
||||
* jni lib: mindspore-lite-jni, mindspore-lite-train-jni
|
||||
*/
|
||||
private static boolean loadLibrary() {
|
||||
boolean loadSuccess = false;
|
||||
try {
|
||||
System.loadLibrary(GLOG_LIBNAME);
|
||||
LOGGER.info("loadLibrary " + GLOG_LIBNAME + ": success");
|
||||
loadSuccess = true;
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
LOGGER.info("tryLoadLibrary " + GLOG_LIBNAME + " failed: " + e.getMessage());
|
||||
LOGGER.info("tryLoadLibrary " + GLOG_LIBNAME + " failed.");
|
||||
}
|
||||
try {
|
||||
System.loadLibrary(MINDSPORE_LITE_LIBNAME);
|
||||
LOGGER.info("loadLibrary " + MINDSPORE_LITE_LIBNAME + ": success");
|
||||
loadSuccess = true;
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
LOGGER.info("tryLoadLibrary " + MINDSPORE_LITE_LIBNAME + " failed: " + e.getMessage());
|
||||
LOGGER.info("tryLoadLibrary " + MINDSPORE_LITE_LIBNAME + " failed.");
|
||||
}
|
||||
try {
|
||||
System.loadLibrary(MINDSPORE_LITE_JNI_LIBNAME);
|
||||
LOGGER.info("loadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + ": success");
|
||||
loadSuccess = true;
|
||||
LOGGER.info("loadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + ": success");
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
LOGGER.info("tryLoadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + " failed: " + e.getMessage());
|
||||
LOGGER.info("tryLoadLibrary " + MINDSPORE_LITE_JNI_LIBNAME + " failed.");
|
||||
}
|
||||
return loadSuccess;
|
||||
}
|
|
@ -16,8 +16,6 @@
|
|||
|
||||
package com.mindspore.config;
|
||||
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
|
||||
/**
|
||||
* Configuration for ModelParallelRunner.
|
||||
*
|
||||
|
@ -25,13 +23,7 @@ import com.mindspore.lite.NativeLibrary;
|
|||
*/
|
||||
public class RunnerConfig {
|
||||
static {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
}
|
||||
MindsporeLite.init();
|
||||
}
|
||||
|
||||
private long runnerConfigPtr;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,7 +17,6 @@
|
|||
package com.mindspore.config;
|
||||
|
||||
public class TrainCfg {
|
||||
// depend "mindspore-lite-train-jni"
|
||||
static {
|
||||
try {
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package com.mindspore.config;
|
||||
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Define mindspore version info.
|
||||
|
@ -24,14 +24,21 @@ import com.mindspore.lite.NativeLibrary;
|
|||
* @since v1.0
|
||||
*/
|
||||
public class Version {
|
||||
private static final Logger LOGGER = Logger.getLogger(Version.class.toString());
|
||||
static {
|
||||
LOGGER.info("Version init ...");
|
||||
init();
|
||||
}
|
||||
|
||||
/**
|
||||
* Init function.
|
||||
*/
|
||||
public static void init() {
|
||||
LOGGER.info("Version init load ...");
|
||||
try {
|
||||
if (!NativeLibrary.loadLibrary()) {
|
||||
NativeLibrary.loadLibs();
|
||||
}
|
||||
NativeLibrary.load();
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
LOGGER.severe("Failed to load MindSporLite native library.");
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package com.mindspore;
|
||||
|
||||
import com.mindspore.config.*;
|
||||
import com.mindspore.lite.NativeLibrary;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
@ -32,23 +31,13 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public class ModelTest {
|
||||
|
||||
@Test
|
||||
public void testBuildByGraphSuccess() {
|
||||
try {
|
||||
NativeLibrary.load();
|
||||
System.err.println("System: NativeLibrary load success.");
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to load MindSporLite native library.");
|
||||
e.printStackTrace();
|
||||
}
|
||||
System.out.println(Version.version());
|
||||
Graph g = new Graph();
|
||||
assertTrue(g.load("../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_train.ms"));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,18 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import com.mindspore.lite.DataType;
|
||||
import com.mindspore.lite.LiteSession;
|
||||
import com.mindspore.lite.MSTensor;
|
||||
import com.mindspore.lite.Model;
|
||||
import com.mindspore.lite.config.DeviceType;
|
||||
import com.mindspore.lite.config.MSConfig;
|
||||
import com.mindspore.MSTensor;
|
||||
import com.mindspore.config.DeviceType;
|
||||
import com.mindspore.config.MSContext;
|
||||
import com.mindspore.config.DataType;
|
||||
import com.mindspore.Model;
|
||||
|
||||
import java.io.*;
|
||||
|
||||
public class Benchmark {
|
||||
private static Model model;
|
||||
private static LiteSession session;
|
||||
|
||||
public static byte[] readBinFile(String fileName, int size) {
|
||||
try {
|
||||
|
@ -58,7 +56,7 @@ public class Benchmark {
|
|||
String[] strings = lineContent.split(" ");
|
||||
if (line++ % 2 == 0) {
|
||||
name = strings[0];
|
||||
outTensor = session.getOutputByTensorName(name);
|
||||
outTensor = model.getOutputByTensorName(name);
|
||||
continue;
|
||||
}
|
||||
float[] benchmarkData = new float[strings.length];
|
||||
|
@ -103,41 +101,7 @@ public class Benchmark {
|
|||
}
|
||||
return meanError < accuracy;
|
||||
}
|
||||
|
||||
private static boolean compile() {
|
||||
MSConfig msConfig = new MSConfig();
|
||||
boolean ret = msConfig.init(DeviceType.DT_CPU, 2);
|
||||
if (!ret) {
|
||||
System.err.println("Init context failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create the MindSpore lite session.
|
||||
session = new LiteSession();
|
||||
ret = session.init(msConfig);
|
||||
msConfig.free();
|
||||
if (!ret) {
|
||||
System.err.println("Create session failed");
|
||||
model.free();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compile graph.
|
||||
ret = session.compileGraph(model);
|
||||
if (!ret) {
|
||||
System.err.println("Compile graph failed");
|
||||
model.free();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
private static void freeBuffer() {
|
||||
session.free();
|
||||
model.free();
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
if (args.length < 4) {
|
||||
System.err.println("We must pass parameters such as modelPath, inDataFile, benchmarkDataFile and accuracy.");
|
||||
|
@ -149,40 +113,42 @@ public class Benchmark {
|
|||
String benchmarkDataFile = args[2];
|
||||
float accuracy = Float.parseFloat(args[3]);
|
||||
|
||||
MSContext context = new MSContext();
|
||||
context.init(1, 0);
|
||||
boolean ret = context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
|
||||
if (!ret) {
|
||||
System.err.println("Compile graph failed");
|
||||
return;
|
||||
}
|
||||
model = new Model();
|
||||
|
||||
|
||||
boolean ret = model.loadModel(modelPath);
|
||||
ret = model.build(modelPath, 0, context);
|
||||
if (!ret) {
|
||||
System.err.println("Load model failed, model path is " + modelPath);
|
||||
System.err.println("Compile graph failed, model path is " + modelPath);
|
||||
model.free();
|
||||
return;
|
||||
}
|
||||
ret = compile();
|
||||
if (!ret) {
|
||||
System.err.println("MindSpore Lite compile failed.");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < session.getInputs().size(); i++) {
|
||||
MSTensor inputTensor = session.getInputs().get(i);
|
||||
if (inputTensor.getDataType() != DataType.kNumberTypeFloat32) {
|
||||
System.err.println("Input tensor data type is not float, the data type is " + inputTensor.getDataType());
|
||||
freeBuffer();
|
||||
for (int index = 0; index < model.getInputs().size(); index++) {
|
||||
MSTensor msTensor = model.getInputs().get(index);
|
||||
if (msTensor.getDataType() != DataType.kNumberTypeFloat32) {
|
||||
System.err.println("Input tensor data type is not float, the data type is " + msTensor.getDataType());
|
||||
model.free();
|
||||
return;
|
||||
}
|
||||
// Set Input Data.
|
||||
byte[] data = readBinFile(inDataFile[i], (int) inputTensor.size());
|
||||
inputTensor.setData(data);
|
||||
byte[] data = readBinFile(inDataFile[index], (int) msTensor.size());
|
||||
msTensor.setData(data);
|
||||
}
|
||||
|
||||
// Run Inference.
|
||||
if (!session.runGraph()) {
|
||||
ret = model.predict();
|
||||
if (!ret) {
|
||||
System.err.println("MindSpore Lite run failed.");
|
||||
freeBuffer();
|
||||
model.free();
|
||||
return;
|
||||
}
|
||||
|
||||
boolean benchmarkResult = compareData(benchmarkDataFile, accuracy);
|
||||
freeBuffer();
|
||||
model.free();
|
||||
if (!benchmarkResult) {
|
||||
System.err.println(modelPath + " accuracy error is too large.");
|
||||
System.exit(1);
|
||||
|
|
Loading…
Reference in New Issue