forked from mindspore-Ecosystem/mindspore
fix lite java inference bug
This commit is contained in:
parent
57464dd61b
commit
5d5675335b
15
build.sh
15
build.sh
|
@ -526,11 +526,6 @@ build_lite_x86_64_jni_and_jar()
|
|||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
|
||||
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
|
||||
fi
|
||||
|
||||
# build java common
|
||||
cd ${LITE_JAVA_PATH}/java/common
|
||||
|
@ -715,10 +710,6 @@ build_lite_arm64_and_jni() {
|
|||
fi
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
|
||||
fi
|
||||
}
|
||||
|
||||
build_lite_arm32_and_jni() {
|
||||
|
@ -760,10 +751,6 @@ build_lite_arm32_and_jni() {
|
|||
fi
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
if [[ "X$is_train" = "Xon" ]]; then
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
|
||||
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
|
||||
fi
|
||||
}
|
||||
|
||||
check_java_home() {
|
||||
|
@ -794,7 +781,7 @@ build_aar() {
|
|||
|
||||
# build java fl_client
|
||||
local is_train=on
|
||||
local train_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libmindspore-lite-train-jni.so
|
||||
local train_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libmindspore-lite-train.so
|
||||
if [ ! -f "$train_so" ]; then
|
||||
echo "not exist"
|
||||
is_train=off
|
||||
|
|
|
@ -20,7 +20,7 @@ import com.mindspore.lite.Version;
|
|||
|
||||
public class Main {
|
||||
public static void main(String[] args) {
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
System.out.println(Version.version());
|
||||
if (args.length < 2) {
|
||||
System.err.println("model path and dataset path must be provided.");
|
||||
|
|
|
@ -28,7 +28,6 @@ import com.mindspore.lite.config.MSConfig;
|
|||
public class LiteSession {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
}
|
||||
|
||||
private long sessionPtr = 0;
|
||||
|
@ -64,9 +63,9 @@ public class LiteSession {
|
|||
}
|
||||
}
|
||||
|
||||
public static LiteSession createTrainSession(String modelname, final MSConfig config, boolean train_mode) {
|
||||
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
|
||||
LiteSession liteSession = new LiteSession();
|
||||
liteSession.sessionPtr = liteSession.createTrainSession(modelname, config.getMSConfigPtr(), train_mode, 0);
|
||||
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
|
||||
if (liteSession.sessionPtr == 0) {
|
||||
return null;
|
||||
} else {
|
||||
|
@ -78,8 +77,8 @@ public class LiteSession {
|
|||
return sessionPtr;
|
||||
}
|
||||
|
||||
public void bindThread(boolean if_bind) {
|
||||
this.bindThread(this.sessionPtr, if_bind);
|
||||
public void bindThread(boolean ifBind) {
|
||||
this.bindThread(this.sessionPtr, ifBind);
|
||||
}
|
||||
|
||||
public boolean compileGraph(Model model) {
|
||||
|
@ -156,8 +155,8 @@ public class LiteSession {
|
|||
return this.resize(this.sessionPtr, inputsArray, dims);
|
||||
}
|
||||
|
||||
public boolean export(String modelFilename, int model_type, int quantization_type) {
|
||||
return this.export(this.sessionPtr, modelFilename, model_type, quantization_type);
|
||||
public boolean export(String modelFileName, int modelType, int quantizationType) {
|
||||
return this.export(this.sessionPtr, modelFileName, modelType, quantizationType);
|
||||
}
|
||||
|
||||
public boolean train() {
|
||||
|
@ -205,11 +204,11 @@ public class LiteSession {
|
|||
|
||||
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
|
||||
|
||||
private native long createTrainSession(String filename, long msContextPtr, boolean train_mode, long msTrainCfgPtr);
|
||||
private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
|
||||
|
||||
private native boolean compileGraph(long sessionPtr, long modelPtr);
|
||||
|
||||
private native void bindThread(long sessionPtr, boolean if_bind);
|
||||
private native void bindThread(long sessionPtr, boolean ifBind);
|
||||
|
||||
private native boolean runGraph(long sessionPtr);
|
||||
|
||||
|
@ -229,7 +228,7 @@ public class LiteSession {
|
|||
|
||||
private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);
|
||||
|
||||
private native boolean export(long sessionPtr, String modelFilename, int model_type, int quantization_type);
|
||||
private native boolean export(long sessionPtr, String modelFileName, int modelType, int quantizationType);
|
||||
|
||||
private native boolean train(long sessionPtr);
|
||||
|
||||
|
@ -239,7 +238,7 @@ public class LiteSession {
|
|||
|
||||
private native boolean isEval(long sessionPtr);
|
||||
|
||||
private native boolean setLearningRate(long sessionPtr, float learning_rate);
|
||||
private native boolean setLearningRate(long sessionPtr, float learningRate);
|
||||
|
||||
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ import com.mindspore.lite.config.MSConfig;
|
|||
public class LiteSession {
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
System.loadLibrary("mindspore-lite-train-jni");
|
||||
}
|
||||
|
||||
private long sessionPtr = 0;
|
||||
|
@ -64,9 +63,9 @@ public class LiteSession {
|
|||
}
|
||||
}
|
||||
|
||||
public static LiteSession createTrainSession(String modelname, final MSConfig config, boolean train_mode) {
|
||||
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
|
||||
LiteSession liteSession = new LiteSession();
|
||||
liteSession.sessionPtr = liteSession.createTrainSession(modelname, config.getMSConfigPtr(), train_mode, 0);
|
||||
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
|
||||
if (liteSession.sessionPtr == 0) {
|
||||
return null;
|
||||
} else {
|
||||
|
@ -78,8 +77,8 @@ public class LiteSession {
|
|||
return sessionPtr;
|
||||
}
|
||||
|
||||
public void bindThread(boolean if_bind) {
|
||||
this.bindThread(this.sessionPtr, if_bind);
|
||||
public void bindThread(boolean ifBind) {
|
||||
this.bindThread(this.sessionPtr, ifBind);
|
||||
}
|
||||
|
||||
public boolean compileGraph(Model model) {
|
||||
|
@ -156,8 +155,8 @@ public class LiteSession {
|
|||
return this.resize(this.sessionPtr, inputsArray, dims);
|
||||
}
|
||||
|
||||
public boolean export(String modelFilename, int model_type, int quantization_type) {
|
||||
return this.export(this.sessionPtr, modelFilename, model_type, quantization_type);
|
||||
public boolean export(String modelFileName, int modelType, int quantizationType) {
|
||||
return this.export(this.sessionPtr, modelFileName, modelType, quantizationType);
|
||||
}
|
||||
|
||||
public boolean train() {
|
||||
|
@ -176,8 +175,8 @@ public class LiteSession {
|
|||
return this.isEval(this.sessionPtr);
|
||||
}
|
||||
|
||||
public boolean setLearningRate(float learning_rate) {
|
||||
return this.setLearningRate(this.sessionPtr, learning_rate);
|
||||
public boolean setLearningRate(float learningRate) {
|
||||
return this.setLearningRate(this.sessionPtr, learningRate);
|
||||
}
|
||||
|
||||
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) {
|
||||
|
@ -205,11 +204,11 @@ public class LiteSession {
|
|||
|
||||
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
|
||||
|
||||
private native long createTrainSession(String filename, long msContextPtr, boolean train_mode, long msTrainCfgPtr);
|
||||
private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
|
||||
|
||||
private native boolean compileGraph(long sessionPtr, long modelPtr);
|
||||
|
||||
private native void bindThread(long sessionPtr, boolean if_bind);
|
||||
private native void bindThread(long sessionPtr, boolean ifBind);
|
||||
|
||||
private native boolean runGraph(long sessionPtr);
|
||||
|
||||
|
@ -229,7 +228,7 @@ public class LiteSession {
|
|||
|
||||
private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);
|
||||
|
||||
private native boolean export(long sessionPtr, String modelFilename, int model_type, int quantization_type);
|
||||
private native boolean export(long sessionPtr, String modelFilename, int modelType, int quantizationType);
|
||||
|
||||
private native boolean train(long sessionPtr);
|
||||
|
||||
|
@ -239,7 +238,7 @@ public class LiteSession {
|
|||
|
||||
private native boolean isEval(long sessionPtr);
|
||||
|
||||
private native boolean setLearningRate(long sessionPtr, float learning_rate);
|
||||
private native boolean setLearningRate(long sessionPtr, float learningRate);
|
||||
|
||||
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
|
||||
|
||||
|
|
|
@ -103,13 +103,11 @@ endif()
|
|||
|
||||
if(SUPPORT_TRAIN)
|
||||
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
|
||||
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
|
||||
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
find_library(log-lib log)
|
||||
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
|
||||
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
|
||||
else()
|
||||
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
|
||||
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
@ -356,8 +356,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setLea
|
|||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupVirtualBatch(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jint virtualBatchMultiplier,
|
||||
jfloat learningRate,
|
||||
jint virtual_batch_factor,
|
||||
jfloat learning_rate,
|
||||
jfloat momentum) {
|
||||
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (session_pointer == nullptr) {
|
||||
|
@ -365,7 +365,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupV
|
|||
return (jboolean) false;
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
|
||||
auto ret = lite_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum);
|
||||
auto ret = lite_session_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
|
||||
return (jboolean)(ret == mindspore::lite::RET_OK);
|
||||
}
|
||||
|
||||
|
@ -411,3 +411,24 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeat
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz,
|
||||
jstring file_name,
|
||||
jlong ms_context_ptr,
|
||||
jboolean train_mode,
|
||||
jlong train_config_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(ms_context_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Context pointer from java is nullptr");
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
|
||||
|
||||
auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
|
||||
lite_context_ptr, train_mode, nullptr);
|
||||
if (session == nullptr) {
|
||||
MS_LOGE("CreateTrainSession failed");
|
||||
return jlong(nullptr);
|
||||
}
|
||||
return jlong(session);
|
||||
}
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/train/train_cfg.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(
|
||||
JNIEnv *env, jobject thiz, jstring file_name, jlong ms_context_ptr, jboolean train_mode, jlong train_config_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(ms_context_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Context pointer from java is nullptr");
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
|
||||
|
||||
auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
|
||||
lite_context_ptr, train_mode, nullptr);
|
||||
if (session == nullptr) {
|
||||
MS_LOGE("CreateSession failed");
|
||||
return jlong(nullptr);
|
||||
}
|
||||
return jlong(session);
|
||||
}
|
Loading…
Reference in New Issue