fix lite java inference bug

This commit is contained in:
zhengjun10 2021-07-03 16:18:05 +08:00
parent 57464dd61b
commit 5d5675335b
7 changed files with 50 additions and 85 deletions

View File

@ -526,11 +526,6 @@ build_lite_x86_64_jni_and_jar()
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
fi
# build java common
cd ${LITE_JAVA_PATH}/java/common
@ -715,10 +710,6 @@ build_lite_arm64_and_jni() {
fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
fi
}
build_lite_arm32_and_jni() {
@ -760,10 +751,6 @@ build_lite_arm32_and_jni() {
fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
fi
}
check_java_home() {
@ -794,7 +781,7 @@ build_aar() {
# build java fl_client
local is_train=on
local train_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libmindspore-lite-train-jni.so
local train_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libmindspore-lite-train.so
if [ ! -f "$train_so" ]; then
echo "not exist"
is_train=off

View File

@ -20,7 +20,7 @@ import com.mindspore.lite.Version;
public class Main {
public static void main(String[] args) {
System.loadLibrary("mindspore-lite-train-jni");
System.loadLibrary("mindspore-lite-jni");
System.out.println(Version.version());
if (args.length < 2) {
System.err.println("model path and dataset path must be provided.");

View File

@ -28,7 +28,6 @@ import com.mindspore.lite.config.MSConfig;
public class LiteSession {
static {
System.loadLibrary("mindspore-lite-jni");
System.loadLibrary("mindspore-lite-train-jni");
}
private long sessionPtr = 0;
@ -64,9 +63,9 @@ public class LiteSession {
}
}
public static LiteSession createTrainSession(String modelname, final MSConfig config, boolean train_mode) {
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
liteSession.sessionPtr = liteSession.createTrainSession(modelname, config.getMSConfigPtr(), train_mode, 0);
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (liteSession.sessionPtr == 0) {
return null;
} else {
@ -78,8 +77,8 @@ public class LiteSession {
return sessionPtr;
}
public void bindThread(boolean if_bind) {
this.bindThread(this.sessionPtr, if_bind);
public void bindThread(boolean ifBind) {
this.bindThread(this.sessionPtr, ifBind);
}
public boolean compileGraph(Model model) {
@ -156,8 +155,8 @@ public class LiteSession {
return this.resize(this.sessionPtr, inputsArray, dims);
}
public boolean export(String modelFilename, int model_type, int quantization_type) {
return this.export(this.sessionPtr, modelFilename, model_type, quantization_type);
public boolean export(String modelFileName, int modelType, int quantizationType) {
return this.export(this.sessionPtr, modelFileName, modelType, quantizationType);
}
public boolean train() {
@ -205,11 +204,11 @@ public class LiteSession {
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
private native long createTrainSession(String filename, long msContextPtr, boolean train_mode, long msTrainCfgPtr);
private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
private native boolean compileGraph(long sessionPtr, long modelPtr);
private native void bindThread(long sessionPtr, boolean if_bind);
private native void bindThread(long sessionPtr, boolean ifBind);
private native boolean runGraph(long sessionPtr);
@ -229,7 +228,7 @@ public class LiteSession {
private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);
private native boolean export(long sessionPtr, String modelFilename, int model_type, int quantization_type);
private native boolean export(long sessionPtr, String modelFileName, int modelType, int quantizationType);
private native boolean train(long sessionPtr);
@ -239,7 +238,7 @@ public class LiteSession {
private native boolean isEval(long sessionPtr);
private native boolean setLearningRate(long sessionPtr, float learning_rate);
private native boolean setLearningRate(long sessionPtr, float learningRate);
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);

View File

@ -28,7 +28,6 @@ import com.mindspore.lite.config.MSConfig;
public class LiteSession {
static {
System.loadLibrary("mindspore-lite-jni");
System.loadLibrary("mindspore-lite-train-jni");
}
private long sessionPtr = 0;
@ -64,9 +63,9 @@ public class LiteSession {
}
}
public static LiteSession createTrainSession(String modelname, final MSConfig config, boolean train_mode) {
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
liteSession.sessionPtr = liteSession.createTrainSession(modelname, config.getMSConfigPtr(), train_mode, 0);
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (liteSession.sessionPtr == 0) {
return null;
} else {
@ -78,8 +77,8 @@ public class LiteSession {
return sessionPtr;
}
public void bindThread(boolean if_bind) {
this.bindThread(this.sessionPtr, if_bind);
public void bindThread(boolean ifBind) {
this.bindThread(this.sessionPtr, ifBind);
}
public boolean compileGraph(Model model) {
@ -156,8 +155,8 @@ public class LiteSession {
return this.resize(this.sessionPtr, inputsArray, dims);
}
public boolean export(String modelFilename, int model_type, int quantization_type) {
return this.export(this.sessionPtr, modelFilename, model_type, quantization_type);
public boolean export(String modelFileName, int modelType, int quantizationType) {
return this.export(this.sessionPtr, modelFileName, modelType, quantizationType);
}
public boolean train() {
@ -176,8 +175,8 @@ public class LiteSession {
return this.isEval(this.sessionPtr);
}
public boolean setLearningRate(float learning_rate) {
return this.setLearningRate(this.sessionPtr, learning_rate);
public boolean setLearningRate(float learningRate) {
return this.setLearningRate(this.sessionPtr, learningRate);
}
public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) {
@ -205,11 +204,11 @@ public class LiteSession {
private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);
private native long createTrainSession(String filename, long msContextPtr, boolean train_mode, long msTrainCfgPtr);
private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr);
private native boolean compileGraph(long sessionPtr, long modelPtr);
private native void bindThread(long sessionPtr, boolean if_bind);
private native void bindThread(long sessionPtr, boolean ifBind);
private native boolean runGraph(long sessionPtr);
@ -229,7 +228,7 @@ public class LiteSession {
private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);
private native boolean export(long sessionPtr, String modelFilename, int model_type, int quantization_type);
private native boolean export(long sessionPtr, String modelFilename, int modelType, int quantizationType);
private native boolean train(long sessionPtr);
@ -239,7 +238,7 @@ public class LiteSession {
private native boolean isEval(long sessionPtr);
private native boolean setLearningRate(long sessionPtr, float learning_rate);
private native boolean setLearningRate(long sessionPtr, float learningRate);
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);

View File

@ -103,13 +103,11 @@ endif()
if(SUPPORT_TRAIN)
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
find_library(log-lib log)
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
else()
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME})
endif()
endif()

View File

@ -356,8 +356,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setLea
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupVirtualBatch(JNIEnv *env, jobject thiz,
jlong session_ptr,
jint virtualBatchMultiplier,
jfloat learningRate,
jint virtual_batch_factor,
jfloat learning_rate,
jfloat momentum) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
@ -365,7 +365,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_setupV
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto ret = lite_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum);
auto ret = lite_session_ptr->SetupVirtualBatch(virtual_batch_factor, learning_rate, momentum);
return (jboolean)(ret == mindspore::lite::RET_OK);
}
@ -411,3 +411,24 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getFeat
}
return ret;
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz,
jstring file_name,
jlong ms_context_ptr,
jboolean train_mode,
jlong train_config_ptr) {
auto *pointer = reinterpret_cast<void *>(ms_context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
lite_context_ptr, train_mode, nullptr);
if (session == nullptr) {
MS_LOGE("CreateTrainSession failed");
return jlong(nullptr);
}
return jlong(session);
}

View File

@ -1,39 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/lite_session.h"
#include "include/train/train_cfg.h"
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(
JNIEnv *env, jobject thiz, jstring file_name, jlong ms_context_ptr, jboolean train_mode, jlong train_config_ptr) {
auto *pointer = reinterpret_cast<void *>(ms_context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::LiteSession::CreateTrainSession(env->GetStringUTFChars(file_name, JNI_FALSE),
lite_context_ptr, train_mode, nullptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);
}
return jlong(session);
}